未验证 提交 8a87d99e 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] unify tracker for single class and multi-calss MOT (#4403)

* add mcfairmot train eval export code

* fix eval and export

* update modelzoo

* add hrnet18_dlafpn

* unify jdetracker

* fix deploy infer for single class

* fix multi class fairmot deploy

* fix mcmot data source and deploy

* fix num_identities

* fix num_identities_dict
上级 56705e1e
metric: MCMOT
num_classes: 10
# using VisDrone2019 MOT dataset with 10 classes as default, you can modify it for your needs.
# for MCMOT training
TrainDataset:
!MCMOTDataSet
dataset_dir: dataset/mot
image_lists: ['visdrone_mcmot.train']
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
label_list: label_list.txt
# for MCMOT evaluation
# If you want to change the MCMOT evaluation dataset, please modify 'data_root'
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: visdrone_mcmot/images/val
keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
# for MCMOT video inference
TestMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
keep_ori_im: True # set True if save visualization images or video
...@@ -33,7 +33,6 @@ CenterNetHead: ...@@ -33,7 +33,6 @@ CenterNetHead:
FairMOTEmbeddingHead: FairMOTEmbeddingHead:
ch_head: 256 ch_head: 256
ch_emb: 128 ch_emb: 128
num_identifiers: 14455 # for mix dataset (Caltech, CityPersons, CUHK-SYSU, PRW, ETHZ and MOT16)
CenterNetPostProcess: CenterNetPostProcess:
max_per_img: 500 max_per_img: 500
......
English | [简体中文](README_cn.md)
# MCFairMOT (Multi-class FairMOT)
## Table of Contents
- [Introduction](#Introduction)
- [Model Zoo](#Model_Zoo)
- [Getting Start](#Getting_Start)
- [Citations](#Citations)
## Introduction
MCFairMOT is the Multi-class extended version of [FairMOT](https://arxiv.org/abs/2004.01888).
## Model Zoo
### MCFairMOT DLA-34 Results on VisDrone2019 Val Set
| backbone | input shape | MOTA | IDF1 | IDS | FPS | download | config |
| :--------------| :------- | :----: | :----: | :---: | :------: | :----: |:----: |
| DLA-34 | 1088x608 | 24.3 | 41.6 | 2314 | - |[model](https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams) | [config](./mcfairmot_dla34_30e_1088x608_visdrone.yml) |
| HRNetV2-W18 | 1088x608 | 20.4 | 39.9 | 2603 | - |[model](https://paddledet.bj.bcebos.com/models/mot/mcfairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone.pdparams) | [config](./mcfairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone.yml) |
**Notes:**
MOTA is the average MOTA of 10 catecories in the VisDrone2019 MOT dataset, and its value is also equal to the average MOTA of all the evaluated video sequences.
## Getting Start
### 1. Training
Training MCFairMOT on 8 GPUs with following command
```bash
python -m paddle.distributed.launch --log_dir=./mcfairmot_dla34_30e_1088x608_visdrone/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml
```
### 2. Evaluation
Evaluating the track performance of MCFairMOT on val dataset in single GPU with following commands:
```bash
# use weights released in PaddleDetection model zoo
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams
# use saved checkpoint in training
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=output/mcfairmot_dla34_30e_1088x608_visdrone/model_final.pdparams
```
**Notes:**
The default evaluation dataset is VisDrone2019 MOT val-set. If you want to change the evaluation dataset, please refer to the following code and modify `configs/datasets/mcmot.yml`
```
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: your_dataset/images/val
keep_ori_im: False # set True if save visualization images or video
```
### 3. Inference
Inference a vidoe on single GPU with following command:
```bash
# inference on video and save a video
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams --video_file={your video name}.mp4 --save_videos
```
**Notes:**
Please make sure that [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first, on Linux(Ubuntu) platform you can directly install it by the following command:`apt-get update && apt-get install -y ffmpeg`.
### 4. Export model
```bash
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams
```
### 5. Using exported model for python inference
```bash
python deploy/python/mot_jde_infer.py --model_dir=output_inference/mcfairmot_dla34_30e_1088x608_visdrone --video_file={your video name}.mp4 --device=GPU --save_mot_txts
```
**Notes:**
The tracking model is used to predict the video, and does not support the prediction of a single image. The visualization video of the tracking results is saved by default. You can add `--save_mot_txts` to save the txt result file, or `--save_images` to save the visualization images.
## Citations
```
@article{zhang2020fair,
title={FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking},
author={Zhang, Yifu and Wang, Chunyu and Wang, Xinggang and Zeng, Wenjun and Liu, Wenyu},
journal={arXiv preprint arXiv:2004.01888},
year={2020}
}
@ARTICLE{9573394,
author={Zhu, Pengfei and Wen, Longyin and Du, Dawei and Bian, Xiao and Fan, Heng and Hu, Qinghua and Ling, Haibin},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
title={Detection and Tracking Meet Drones Challenge},
year={2021},
volume={},
number={},
pages={1-1},
doi={10.1109/TPAMI.2021.3119563}
}
```
简体中文 | [English](README.md)
# MCFairMOT (Multi-class FairMOT)
## 内容
- [简介](#简介)
- [模型库](#模型库)
- [快速开始](#快速开始)
- [引用](#引用)
## 内容
MCFairMOT是[FairMOT](https://arxiv.org/abs/2004.01888)的多类别扩展版本。
## 模型库
### MCFairMOT DLA-34 在VisDrone2019 MOT val-set上结果
| 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FPS | 下载链接 | 配置文件 |
| :--------------| :------- | :----: | :----: | :---: | :------: | :----: |:----: |
| DLA-34 | 1088x608 | 24.3 | 41.6 | 2314 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams) | [配置文件](./mcfairmot_dla34_30e_1088x608_visdrone.yml) |
| HRNetV2-W18 | 1088x608 | 20.4 | 39.9 | 2603 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/mcfairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone.pdparams) | [配置文件](./mcfairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone.yml) |
**注意:**
MOTA是VisDrone2019 MOT数据集10类目标的平均MOTA, 其值也等于所有评估的视频序列的平均MOTA。
## 快速开始
### 1. 训练
使用8个GPU通过如下命令一键式启动训练
```bash
python -m paddle.distributed.launch --log_dir=./mcfairmot_dla34_30e_1088x608_visdrone/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml
```
### 2. 评估
使用单张GPU通过如下命令一键式启动评估
```bash
# 使用PaddleDetection发布的权重
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams
# 使用训练保存的checkpoint
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=output/mcfairmot_dla34_30e_1088x608_visdrone/model_final.pdparams
```
**注意:**
默认评估的是VisDrone2019 MOT val-set数据集, 如需换评估数据集可参照以下代码修改`configs/datasets/mcmot.yml`
```
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: your_dataset/images/val
keep_ori_im: False # set True if save visualization images or video
```
### 3. 预测
使用单个GPU通过如下命令预测一个视频,并保存为视频
```bash
# 预测一个视频
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams --video_file={your video name}.mp4 --save_videos
```
**注意:**
请先确保已经安装了[ffmpeg](https://ffmpeg.org/ffmpeg.html), Linux(Ubuntu)平台可以直接用以下命令安装:`apt-get update && apt-get install -y ffmpeg`
### 4. 导出预测模型
```bash
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone.pdparams
```
### 5. 用导出的模型基于Python去预测
```bash
python deploy/python/mot_jde_infer.py --model_dir=output_inference/mcfairmot_dla34_30e_1088x608_visdrone --video_file={your video name}.mp4 --device=GPU --save_mot_txts
```
**注意:**
跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`表示保存跟踪结果的txt文件,或`--save_images`表示保存跟踪结果可视化图片。
## 引用
```
@article{zhang2020fair,
title={FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking},
author={Zhang, Yifu and Wang, Chunyu and Wang, Xinggang and Zeng, Wenjun and Liu, Wenyu},
journal={arXiv preprint arXiv:2004.01888},
year={2020}
}
@ARTICLE{9573394,
author={Zhu, Pengfei and Wen, Longyin and Du, Dawei and Bian, Xiao and Fan, Heng and Hu, Qinghua and Ling, Haibin},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
title={Detection and Tracking Meet Drones Challenge},
year={2021},
volume={},
number={},
pages={1-1},
doi={10.1109/TPAMI.2021.3119563}
}
```
_BASE_: [
'../fairmot/fairmot_dla34_30e_1088x608.yml',
'../../datasets/mcmot.yml'
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/fairmot_dla34_crowdhuman_pretrained.pdparams
FairMOT:
detector: CenterNet
reid: FairMOTEmbeddingHead
loss: FairMOTLoss
tracker: JDETracker # multi-class tracker
CenterNetHead:
regress_ltrb: False
CenterNetPostProcess:
for_mot: True
regress_ltrb: False
max_per_img: 200
JDETracker:
min_box_area: 0
vertical_ratio: 0 # no need to filter bboxes according to w/h
conf_thres: 0.4
tracked_thresh: 0.4
metric_type: cosine
weights: output/mcfairmot_dla34_30e_1088x608_visdrone/model_final
epoch: 30
LearningRate:
base_lr: 0.0005
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [10, 20]
use_warmup: False
OptimizerBuilder:
optimizer:
type: Adam
regularizer: NULL
_BASE_: [
'../fairmot/fairmot_hrnetv2_w18_dlafpn_30e_1088x608.yml',
'../../datasets/mcmot.yml'
]
architecture: FairMOT
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/HRNet_W18_C_pretrained.pdparams
for_mot: True
FairMOT:
detector: CenterNet
reid: FairMOTEmbeddingHead
loss: FairMOTLoss
tracker: JDETracker # multi-class tracker
CenterNetHead:
regress_ltrb: False
CenterNetPostProcess:
regress_ltrb: False
max_per_img: 200
JDETracker:
min_box_area: 0
vertical_ratio: 0 # no need to filter bboxes according to w/h
conf_thres: 0.4
tracked_thresh: 0.4
metric_type: cosine
weights: output/mcfairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone/model_final
epoch: 30
LearningRate:
base_lr: 0.0005
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [10, 20]
use_warmup: False
OptimizerBuilder:
optimizer:
type: Adam
regularizer: NULL
TrainReader:
batch_size: 8
...@@ -17,18 +17,20 @@ import time ...@@ -17,18 +17,20 @@ import time
import yaml import yaml
import cv2 import cv2
import numpy as np import numpy as np
import paddle from collections import defaultdict
from benchmark_utils import PaddleInferBenchmark
from preprocess import preprocess
from tracker import JDETracker
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.modeling.mot.utils import Timer as MOTTimer
import paddle
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
from preprocess import preprocess
from utils import argsparser, Timer, get_current_memory_mb from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, PredictConfig from infer import Detector, get_test_images, print_arguments, PredictConfig
from benchmark_utils import PaddleInferBenchmark
from ppdet.modeling.mot.tracker import JDETracker
from ppdet.modeling.mot.visualization import plot_tracking_dict
from ppdet.modeling.mot.utils import MOTTimer, write_mot_results
# Global dictionary # Global dictionary
MOT_SUPPORT_MODELS = { MOT_SUPPORT_MODELS = {
...@@ -80,13 +82,17 @@ class JDE_Detector(Detector): ...@@ -80,13 +82,17 @@ class JDE_Detector(Detector):
enable_mkldnn=enable_mkldnn) enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now" assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
assert pred_config.tracker, "Tracking model should have tracker" assert pred_config.tracker, "Tracking model should have tracker"
self.num_classes = len(pred_config.labels)
tp = pred_config.tracker tp = pred_config.tracker
min_box_area = tp['min_box_area'] if 'min_box_area' in tp else 200 min_box_area = tp['min_box_area'] if 'min_box_area' in tp else 200
vertical_ratio = tp['vertical_ratio'] if 'vertical_ratio' in tp else 1.6 vertical_ratio = tp['vertical_ratio'] if 'vertical_ratio' in tp else 1.6
conf_thres = tp['conf_thres'] if 'conf_thres' in tp else 0. conf_thres = tp['conf_thres'] if 'conf_thres' in tp else 0.
tracked_thresh = tp['tracked_thresh'] if 'tracked_thresh' in tp else 0.7 tracked_thresh = tp['tracked_thresh'] if 'tracked_thresh' in tp else 0.7
metric_type = tp['metric_type'] if 'metric_type' in tp else 'euclidean' metric_type = tp['metric_type'] if 'metric_type' in tp else 'euclidean'
self.tracker = JDETracker( self.tracker = JDETracker(
num_classes=self.num_classes,
min_box_area=min_box_area, min_box_area=min_box_area,
vertical_ratio=vertical_ratio, vertical_ratio=vertical_ratio,
conf_thres=conf_thres, conf_thres=conf_thres,
...@@ -94,25 +100,25 @@ class JDE_Detector(Detector): ...@@ -94,25 +100,25 @@ class JDE_Detector(Detector):
metric_type=metric_type) metric_type=metric_type)
def postprocess(self, pred_dets, pred_embs, threshold): def postprocess(self, pred_dets, pred_embs, threshold):
online_targets = self.tracker.update(pred_dets, pred_embs) online_targets_dict = self.tracker.update(pred_dets, pred_embs)
if online_targets == []:
# First few frames, the model may have no tracking results but have online_tlwhs = defaultdict(list)
# detection results,use the detection results instead, and set id -1. online_scores = defaultdict(list)
return [pred_dets[0][:4]], [pred_dets[0][4]], [-1] online_ids = defaultdict(list)
online_tlwhs, online_ids = [], [] for cls_id in range(self.num_classes):
online_scores = [] online_targets = online_targets_dict[cls_id]
for t in online_targets: for t in online_targets:
tlwh = t.tlwh tlwh = t.tlwh
tid = t.track_id tid = t.track_id
tscore = t.score tscore = t.score
if tscore < threshold: continue if tscore < threshold: continue
if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[ if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > self.tracker.vertical_ratio: 3] > self.tracker.vertical_ratio:
continue continue
online_tlwhs.append(tlwh) online_tlwhs[cls_id].append(tlwh)
online_ids.append(tid) online_ids[cls_id].append(tid)
online_scores.append(tscore) online_scores[cls_id].append(tscore)
return online_tlwhs, online_scores, online_ids return online_tlwhs, online_scores, online_ids
def predict(self, image_list, threshold=0.5, warmup=0, repeats=1): def predict(self, image_list, threshold=0.5, warmup=0, repeats=1):
...@@ -121,7 +127,7 @@ class JDE_Detector(Detector): ...@@ -121,7 +127,7 @@ class JDE_Detector(Detector):
image_list (list): list of image image_list (list): list of image
threshold (float): threshold of predicted box' score threshold (float): threshold of predicted box' score
Returns: Returns:
online_tlwhs, online_scores, online_ids (np.ndarray) online_tlwhs, online_scores, online_ids (dict[np.array])
''' '''
self.det_times.preprocess_time_s.start() self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image_list) inputs = self.preprocess(image_list)
...@@ -157,38 +163,12 @@ class JDE_Detector(Detector): ...@@ -157,38 +163,12 @@ class JDE_Detector(Detector):
return online_tlwhs, online_scores, online_ids return online_tlwhs, online_scores, online_ids
def write_mot_results(filename, results, data_type='mot'):
if data_type in ['mot', 'mcmot', 'lab']:
save_format = '{frame},{id},{x1},{y1},{w},{h},{score},-1,-1,-1\n'
elif data_type == 'kitti':
save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
else:
raise ValueError(data_type)
with open(filename, 'w') as f:
for frame_id, tlwhs, tscores, track_ids in results:
if data_type == 'kitti':
frame_id -= 1
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
x1, y1, w, h = tlwh
x2, y2 = x1 + w, y1 + h
line = save_format.format(
frame=frame_id,
id=track_id,
x1=x1,
y1=y1,
x2=x2,
y2=y2,
w=w,
h=h,
score=score)
f.write(line)
def predict_image(detector, image_list): def predict_image(detector, image_list):
results = [] results = []
num_classes = detector.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot'
image_list.sort() image_list.sort()
for i, img_file in enumerate(image_list): for frame_id, img_file in enumerate(image_list):
frame = cv2.imread(img_file) frame = cv2.imread(img_file)
if FLAGS.run_benchmark: if FLAGS.run_benchmark:
detector.predict([frame], FLAGS.threshold, warmup=10, repeats=10) detector.predict([frame], FLAGS.threshold, warmup=10, repeats=10)
...@@ -196,12 +176,12 @@ def predict_image(detector, image_list): ...@@ -196,12 +176,12 @@ def predict_image(detector, image_list):
detector.cpu_mem += cm detector.cpu_mem += cm
detector.gpu_mem += gm detector.gpu_mem += gm
detector.gpu_util += gu detector.gpu_util += gu
print('Test iter {}, file name:{}'.format(i, img_file)) print('Test iter {}, file name:{}'.format(frame_id, img_file))
else: else:
online_tlwhs, online_scores, online_ids = detector.predict( online_tlwhs, online_scores, online_ids = detector.predict(
[frame], FLAGS.threshold) [frame], FLAGS.threshold)
online_im = mot_vis.plot_tracking( online_im = plot_tracking_dict(frame, num_classes, online_tlwhs,
frame, online_tlwhs, online_ids, online_scores, frame_id=i) online_ids, online_scores, frame_id)
if FLAGS.save_images: if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
...@@ -233,7 +213,9 @@ def predict_video(detector, camera_id): ...@@ -233,7 +213,9 @@ def predict_video(detector, camera_id):
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0 frame_id = 0
timer = MOTTimer() timer = MOTTimer()
results = [] results = defaultdict(list) # support single class and multi classes
num_classes = detector.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot'
while (1): while (1):
ret, frame = capture.read() ret, frame = capture.read()
if not ret: if not ret:
...@@ -243,10 +225,14 @@ def predict_video(detector, camera_id): ...@@ -243,10 +225,14 @@ def predict_video(detector, camera_id):
[frame], FLAGS.threshold) [frame], FLAGS.threshold)
timer.toc() timer.toc()
results.append((frame_id + 1, online_tlwhs, online_scores, online_ids)) for cls_id in range(num_classes):
results[cls_id].append((frame_id + 1, online_tlwhs[cls_id],
online_scores[cls_id], online_ids[cls_id]))
fps = 1. / timer.average_time fps = 1. / timer.average_time
im = mot_vis.plot_tracking( im = plot_tracking_dict(
frame, frame,
num_classes,
online_tlwhs, online_tlwhs,
online_ids, online_ids,
online_scores, online_scores,
...@@ -261,14 +247,6 @@ def predict_video(detector, camera_id): ...@@ -261,14 +247,6 @@ def predict_video(detector, camera_id):
else: else:
writer.write(im) writer.write(im)
if FLAGS.save_mot_txt_per_img:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
result_filename = os.path.join(save_dir,
'{:05d}.txt'.format(frame_id))
write_mot_results(result_filename, [results[-1]])
frame_id += 1 frame_id += 1
print('detect frame: %d' % (frame_id)) print('detect frame: %d' % (frame_id))
if camera_id != -1: if camera_id != -1:
...@@ -278,7 +256,8 @@ def predict_video(detector, camera_id): ...@@ -278,7 +256,8 @@ def predict_video(detector, camera_id):
if FLAGS.save_mot_txts: if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir, result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt') video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results)
write_mot_results(result_filename, results, data_type, num_classes)
if FLAGS.save_images: if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2]) save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
......
...@@ -15,23 +15,23 @@ ...@@ -15,23 +15,23 @@
import os import os
import cv2 import cv2
import math import math
import copy
import numpy as np import numpy as np
from collections import defaultdict
import paddle import paddle
import copy
from mot_keypoint_unite_utils import argsparser from utils import get_current_memory_mb
from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint from infer import Detector, PredictConfig, print_arguments, get_test_images
from visualize import draw_pose from visualize import draw_pose
from benchmark_utils import PaddleInferBenchmark
from utils import Timer
from tracker import JDETracker from mot_keypoint_unite_utils import argsparser
from mot_jde_infer import JDE_Detector, write_mot_results from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
from infer import Detector, PredictConfig, print_arguments, get_test_images
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.modeling.mot.utils import Timer as FPSTimer
from utils import get_current_memory_mb
from det_keypoint_unite_infer import predict_with_given_det, bench_log from det_keypoint_unite_infer import predict_with_given_det, bench_log
from mot_jde_infer import JDE_Detector
from ppdet.modeling.mot.visualization import plot_tracking_dict
from ppdet.modeling.mot.utils import MOTTimer as FPSTimer
from ppdet.modeling.mot.utils import write_mot_results
# Global dictionary # Global dictionary
KEYPOINT_SUPPORT_MODELS = { KEYPOINT_SUPPORT_MODELS = {
...@@ -56,6 +56,9 @@ def mot_keypoint_unite_predict_image(mot_model, ...@@ -56,6 +56,9 @@ def mot_keypoint_unite_predict_image(mot_model,
keypoint_model, keypoint_model,
image_list, image_list,
keypoint_batch_size=1): keypoint_batch_size=1):
num_classes = mot_model.num_classes
assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.'
data_type = 'mot'
image_list.sort() image_list.sort()
for i, img_file in enumerate(image_list): for i, img_file in enumerate(image_list):
frame = cv2.imread(img_file) frame = cv2.imread(img_file)
...@@ -104,9 +107,13 @@ def mot_keypoint_unite_predict_image(mot_model, ...@@ -104,9 +107,13 @@ def mot_keypoint_unite_predict_image(mot_model,
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown' if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown'
else None) else None)
online_im = mot_vis.plot_tracking( online_im = plot_tracking_dict(
im, online_tlwhs, online_ids, online_scores, frame_id=i) im,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=i)
if FLAGS.save_images: if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
...@@ -143,7 +150,13 @@ def mot_keypoint_unite_predict_video(mot_model, ...@@ -143,7 +150,13 @@ def mot_keypoint_unite_predict_video(mot_model,
timer_mot = FPSTimer() timer_mot = FPSTimer()
timer_kp = FPSTimer() timer_kp = FPSTimer()
timer_mot_kp = FPSTimer() timer_mot_kp = FPSTimer()
mot_results = []
# support single class and multi classes, but should be single class here
mot_results = defaultdict(list)
num_classes = mot_model.num_classes
assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.'
data_type = 'mot'
while (1): while (1):
ret, frame = capture.read() ret, frame = capture.read()
if not ret: if not ret:
...@@ -153,15 +166,15 @@ def mot_keypoint_unite_predict_video(mot_model, ...@@ -153,15 +166,15 @@ def mot_keypoint_unite_predict_video(mot_model,
online_tlwhs, online_scores, online_ids = mot_model.predict( online_tlwhs, online_scores, online_ids = mot_model.predict(
[frame], FLAGS.mot_threshold) [frame], FLAGS.mot_threshold)
timer_mot.toc() timer_mot.toc()
mot_results.append( mot_results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids)) (frame_id + 1, online_tlwhs[0], online_scores[0], online_ids[0]))
mot_fps = 1. / timer_mot.average_time mot_fps = 1. / timer_mot.average_time
timer_kp.tic() timer_kp.tic()
keypoint_arch = keypoint_model.pred_config.arch keypoint_arch = keypoint_model.pred_config.arch
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown': if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown':
results = convert_mot_to_det(online_tlwhs, online_scores) results = convert_mot_to_det(online_tlwhs[0], online_scores[0])
keypoint_results = predict_with_given_det( keypoint_results = predict_with_given_det(
frame, results, keypoint_model, keypoint_batch_size, frame, results, keypoint_model, keypoint_batch_size,
FLAGS.mot_threshold, FLAGS.keypoint_threshold, FLAGS.mot_threshold, FLAGS.keypoint_threshold,
...@@ -184,8 +197,9 @@ def mot_keypoint_unite_predict_video(mot_model, ...@@ -184,8 +197,9 @@ def mot_keypoint_unite_predict_video(mot_model,
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown' else if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown' else
None) None)
online_im = mot_vis.plot_tracking( online_im = plot_tracking_dict(
im, im,
num_classes,
online_tlwhs, online_tlwhs,
online_ids, online_ids,
online_scores, online_scores,
...@@ -212,7 +226,7 @@ def mot_keypoint_unite_predict_video(mot_model, ...@@ -212,7 +226,7 @@ def mot_keypoint_unite_predict_video(mot_model,
if FLAGS.save_mot_txts: if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir, result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt') video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, mot_results) write_mot_results(result_filename, mot_results, data_type, num_classes)
if FLAGS.save_images: if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2]) save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
......
...@@ -22,7 +22,7 @@ from benchmark_utils import PaddleInferBenchmark ...@@ -22,7 +22,7 @@ from benchmark_utils import PaddleInferBenchmark
from preprocess import preprocess from preprocess import preprocess
from tracker import DeepSORTTracker from tracker import DeepSORTTracker
from ppdet.modeling.mot import visualization as mot_vis from ppdet.modeling.mot import visualization as mot_vis
from ppdet.modeling.mot.utils import Timer as MOTTimer from ppdet.modeling.mot.utils import MOTTimer
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
......
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import jde_tracker
from . import deepsort_tracker from . import deepsort_tracker
from .jde_tracker import *
from .deepsort_tracker import * from .deepsort_tracker import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import numpy as np
from ppdet.modeling.mot.matching import jde_matching as matching
from ppdet.modeling.mot.motion import KalmanFilter
from ppdet.modeling.mot.tracker.base_jde_tracker import TrackState, BaseTrack, STrack
from ppdet.modeling.mot.tracker.base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
__all__ = ['JDETracker']
class JDETracker(object):
"""
JDE tracker
Args:
det_thresh (float): threshold of detection score
track_buffer (int): buffer for tracker
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
bad results, set 1.6 default for pedestrian tracking. If set -1
means no need to filter bboxes.
tracked_thresh (float): linear assignment threshold of tracked
stracks and detections
r_tracked_thresh (float): linear assignment threshold of
tracked stracks and unmatched detections
unconfirmed_thresh (float): linear assignment threshold of
unconfirmed stracks and unmatched detections
motion (object): KalmanFilter instance
conf_thres (float): confidence threshold for tracking
metric_type (str): either "euclidean" or "cosine", the distance metric
used for measurement to track association.
"""
def __init__(self,
det_thresh=0.3,
track_buffer=30,
min_box_area=200,
vertical_ratio=1.6,
tracked_thresh=0.7,
r_tracked_thresh=0.5,
unconfirmed_thresh=0.7,
motion='KalmanFilter',
conf_thres=0,
metric_type='euclidean'):
self.det_thresh = det_thresh
self.track_buffer = track_buffer
self.min_box_area = min_box_area
self.vertical_ratio = vertical_ratio
self.tracked_thresh = tracked_thresh
self.r_tracked_thresh = r_tracked_thresh
self.unconfirmed_thresh = unconfirmed_thresh
self.motion = KalmanFilter()
self.conf_thres = conf_thres
self.metric_type = metric_type
self.frame_id = 0
self.tracked_stracks = []
self.lost_stracks = []
self.removed_stracks = []
self.max_time_lost = 0
# max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
def update(self, pred_dets, pred_embs):
"""
Processes the image frame and finds bounding box(detections).
Associates the detection with corresponding tracklets and also handles
lost, removed, refound and active tracklets.
Args:
pred_dets (Tensor): Detection results of the image, shape is [N, 5].
pred_embs (Tensor): Embedding results of the image, shape is [N, 512].
Return:
output_stracks (list): The list contains information regarding the
online_tracklets for the recieved image tensor.
"""
self.frame_id += 1
activated_starcks = []
# for storing active tracks, for the current frame
refind_stracks = []
# Lost Tracks whose detections are obtained in the current frame
lost_stracks = []
# The tracks which are not obtained in the current frame but are not
# removed. (Lost for some time lesser than the threshold for removing)
removed_stracks = []
remain_inds = np.nonzero(pred_dets[:, 4] > self.conf_thres)
if len(remain_inds) == 0:
pred_dets = np.zeros([0, 1])
pred_embs = np.zeros([0, 1])
else:
pred_dets = pred_dets[remain_inds]
pred_embs = pred_embs[remain_inds]
# Filter out the image with box_num = 0. pred_dets = [[0.0, 0.0, 0.0 ,0.0]]
empty_pred = True if len(pred_dets) == 1 and np.sum(
pred_dets) == 0.0 else False
""" Step 1: Network forward, get detections & embeddings"""
if len(pred_dets) > 0 and not empty_pred:
detections = [
STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30)
for (tlbrs, f) in zip(pred_dets, pred_embs)
]
else:
detections = []
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed = []
tracked_stracks = [] # type: list[STrack]
for track in self.tracked_stracks:
if not track.is_activated:
# previous tracks which are not active in the current frame are added in unconfirmed list
unconfirmed.append(track)
else:
# Active tracks are added to the local list 'tracked_stracks'
tracked_stracks.append(track)
""" Step 2: First association, with embedding"""
# Combining currently tracked_stracks and lost_stracks
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF
STrack.multi_predict(strack_pool, self.motion)
dists = matching.embedding_distance(
strack_pool, detections, metric=self.metric_type)
dists = matching.fuse_motion(self.motion, dists, strack_pool,
detections)
# The dists is the list of distances of the detection with the tracks in strack_pool
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.tracked_thresh)
# The matches is the array for corresponding matches of the detection with the corresponding strack_pool
for itracked, idet in matches:
# itracked is the id of the track and idet is the detection
track = strack_pool[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
# If the track is active, add the detection to the track
track.update(detections[idet], self.frame_id)
activated_starcks.append(track)
else:
# We have obtained a detection from a track which is not active,
# hence put the track in refind_stracks list
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
# None of the steps below happen if there are no undetected tracks.
""" Step 3: Second association, with IOU"""
detections = [detections[i] for i in u_detection]
# detections is now a list of the unmatched detections
r_tracked_stracks = []
# This is container for stracks which were tracked till the previous
# frame but no detection was found for it in the current frame.
for i in u_track:
if strack_pool[i].state == TrackState.Tracked:
r_tracked_stracks.append(strack_pool[i])
dists = matching.iou_distance(r_tracked_stracks, detections)
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.r_tracked_thresh)
# matches is the list of detections which matched with corresponding
# tracks by IOU distance method.
for itracked, idet in matches:
track = r_tracked_stracks[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_id)
activated_starcks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
# Same process done for some unmatched detections, but now considering IOU_distance as measure
for it in u_track:
track = r_tracked_stracks[it]
if not track.state == TrackState.Lost:
track.mark_lost()
lost_stracks.append(track)
# If no detections are obtained for tracks (u_track), the tracks are added to lost_tracks list and are marked lost
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections = [detections[i] for i in u_detection]
dists = matching.iou_distance(unconfirmed, detections)
matches, u_unconfirmed, u_detection = matching.linear_assignment(
dists, thresh=self.unconfirmed_thresh)
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id)
activated_starcks.append(unconfirmed[itracked])
# The tracks which are yet not matched
for it in u_unconfirmed:
track = unconfirmed[it]
track.mark_removed()
removed_stracks.append(track)
# after all these confirmation steps, if a new detection is found, it is initialized for a new track
""" Step 4: Init new stracks"""
for inew in u_detection:
track = detections[inew]
if track.score < self.det_thresh:
continue
track.activate(self.motion, self.frame_id)
activated_starcks.append(track)
""" Step 5: Update state"""
# If the tracks are lost for more frames than the threshold number, the tracks are removed.
for track in self.lost_stracks:
if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed()
removed_stracks.append(track)
# Update the self.tracked_stracks and self.lost_stracks using the updates in this step.
self.tracked_stracks = [
t for t in self.tracked_stracks if t.state == TrackState.Tracked
]
self.tracked_stracks = joint_stracks(self.tracked_stracks,
activated_starcks)
self.tracked_stracks = joint_stracks(self.tracked_stracks,
refind_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
self.removed_stracks.extend(removed_stracks)
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
self.tracked_stracks, self.lost_stracks)
# get scores of lost tracks
output_stracks = [
track for track in self.tracked_stracks if track.is_activated
]
return output_stracks
...@@ -90,9 +90,12 @@ def get_categories(metric_type, anno_file=None, arch=None): ...@@ -90,9 +90,12 @@ def get_categories(metric_type, anno_file=None, arch=None):
elif metric_type.lower() in ['mot', 'motdet', 'reid']: elif metric_type.lower() in ['mot', 'motdet', 'reid']:
return _mot_category() return _mot_category()
elif metric_type.lower() in ['kitti', 'bdd100k']: elif metric_type.lower() in ['kitti', 'bdd100kmot']:
return _mot_category(category='car') return _mot_category(category='car')
elif metric_type.lower() in ['mcmot']:
return _visdrone_category()
else: else:
raise ValueError("unknown metric type {}".format(metric_type)) raise ValueError("unknown metric type {}".format(metric_type))
...@@ -825,3 +828,21 @@ def _oid19_category(): ...@@ -825,3 +828,21 @@ def _oid19_category():
} }
return clsid2catid, catid2name return clsid2catid, catid2name
def _visdrone_category():
clsid2catid = {i: i for i in range(10)}
catid2name = {
0: 'pedestrian',
1: 'people',
2: 'bicycle',
3: 'car',
4: 'van',
5: 'truck',
6: 'tricycle',
7: 'awning-tricycle',
8: 'bus',
9: 'motor'
}
return clsid2catid, catid2name
...@@ -17,7 +17,7 @@ import sys ...@@ -17,7 +17,7 @@ import sys
import cv2 import cv2
import glob import glob
import numpy as np import numpy as np
from collections import OrderedDict from collections import OrderedDict, defaultdict
try: try:
from collections.abc import Sequence from collections.abc import Sequence
except Exception: except Exception:
...@@ -32,7 +32,8 @@ logger = setup_logger(__name__) ...@@ -32,7 +32,8 @@ logger = setup_logger(__name__)
@serializable @serializable
class MOTDataSet(DetDataset): class MOTDataSet(DetDataset):
""" """
Load dataset with MOT format. Load dataset with MOT format, only support single class MOT.
Args: Args:
dataset_dir (str): root directory for dataset. dataset_dir (str): root directory for dataset.
image_lists (str|list): mot data image lists, muiti-source mot dataset. image_lists (str|list): mot data image lists, muiti-source mot dataset.
...@@ -152,18 +153,17 @@ class MOTDataSet(DetDataset): ...@@ -152,18 +153,17 @@ class MOTDataSet(DetDataset):
self.tid_start_index[k] = last_index self.tid_start_index[k] = last_index
last_index += v last_index += v
self.total_identities = int(last_index + 1) self.num_identities_dict = defaultdict(int)
self.num_identities_dict[0] = int(last_index + 1) # single class
self.num_imgs_each_data = [len(x) for x in self.img_files.values()] self.num_imgs_each_data = [len(x) for x in self.img_files.values()]
self.total_imgs = sum(self.num_imgs_each_data) self.total_imgs = sum(self.num_imgs_each_data)
logger.info('=' * 80)
logger.info('MOT dataset summary: ') logger.info('MOT dataset summary: ')
logger.info(self.tid_num) logger.info(self.tid_num)
logger.info('total images: {}'.format(self.total_imgs)) logger.info('Total images: {}'.format(self.total_imgs))
logger.info('image start index: {}'.format(self.img_start_index)) logger.info('Image start index: {}'.format(self.img_start_index))
logger.info('total identities: {}'.format(self.total_identities)) logger.info('Total identities: {}'.format(self.num_identities_dict[0]))
logger.info('identity start index: {}'.format(self.tid_start_index)) logger.info('Identity start index: {}'.format(self.tid_start_index))
logger.info('=' * 80)
records = [] records = []
cname2cid = mot_label() cname2cid = mot_label()
...@@ -222,9 +222,223 @@ class MOTDataSet(DetDataset): ...@@ -222,9 +222,223 @@ class MOTDataSet(DetDataset):
self.roidbs, self.cname2cid = records, cname2cid self.roidbs, self.cname2cid = records, cname2cid
def mot_label(): @register
labels_map = {'person': 0} @serializable
return labels_map class MCMOTDataSet(DetDataset):
"""
Load dataset with MOT format, support multi-class MOT.
Args:
dataset_dir (str): root directory for dataset.
image_lists (list(str)): mcmot data image lists, muiti-source mcmot dataset.
data_fields (list): key name of data dictionary, at least have 'image'.
label_list (str): if use_default_label is False, will load
mapping between category and class index.
sample_num (int): number of samples to load, -1 means all.
Notes:
MCMOT datasets root directory following this:
dataset/mot
|——————image_lists
| |——————visdrone_mcmot.train
| |——————visdrone_mcmot.val
visdrone_mcmot
|——————images
| └——————train
| └——————val
└——————labels_with_ids
└——————train
"""
def __init__(self,
dataset_dir=None,
image_lists=[],
data_fields=['image'],
label_list=None,
sample_num=-1):
super(MCMOTDataSet, self).__init__(
dataset_dir=dataset_dir,
data_fields=data_fields,
sample_num=sample_num)
self.dataset_dir = dataset_dir
self.image_lists = image_lists
if isinstance(self.image_lists, str):
self.image_lists = [self.image_lists]
self.label_list = label_list
self.roidbs = None
self.cname2cid = None
def get_anno(self):
if self.image_lists == []:
return
# only used to get categories and metric
return os.path.join(self.dataset_dir, 'image_lists',
self.image_lists[0])
def parse_dataset(self):
self.img_files = OrderedDict()
self.img_start_index = OrderedDict()
self.label_files = OrderedDict()
self.tid_num = OrderedDict()
self.tid_start_idx_of_cls_ids = defaultdict(dict) # for MCMOT
img_index = 0
for data_name in self.image_lists:
# check every data image list
image_lists_dir = os.path.join(self.dataset_dir, 'image_lists')
assert os.path.isdir(image_lists_dir), \
"The {} is not a directory.".format(image_lists_dir)
list_path = os.path.join(image_lists_dir, data_name)
assert os.path.exists(list_path), \
"The list path {} does not exist.".format(list_path)
# record img_files, filter out empty ones
with open(list_path, 'r') as file:
self.img_files[data_name] = file.readlines()
self.img_files[data_name] = [
os.path.join(self.dataset_dir, x.strip())
for x in self.img_files[data_name]
]
self.img_files[data_name] = list(
filter(lambda x: len(x) > 0, self.img_files[data_name]))
self.img_start_index[data_name] = img_index
img_index += len(self.img_files[data_name])
# record label_files
self.label_files[data_name] = [
x.replace('images', 'labels_with_ids').replace(
'.png', '.txt').replace('.jpg', '.txt')
for x in self.img_files[data_name]
]
for data_name, label_paths in self.label_files.items():
# using max_ids_dict rather than max_index
max_ids_dict = defaultdict(int)
for lp in label_paths:
lb = np.loadtxt(lp)
if len(lb) < 1:
continue
lb = lb.reshape(-1, 6)
for item in lb:
if item[1] > max_ids_dict[int(item[0])]:
# item[0]: cls_id
# item[1]: track id
max_ids_dict[int(item[0])] = int(item[1])
# track id number
self.tid_num[data_name] = max_ids_dict
last_idx_dict = defaultdict(int)
for i, (k, v) in enumerate(self.tid_num.items()): # each sub dataset
for cls_id, id_num in v.items(): # v is a max_ids_dict
self.tid_start_idx_of_cls_ids[k][cls_id] = last_idx_dict[cls_id]
last_idx_dict[cls_id] += id_num
self.num_identities_dict = defaultdict(int)
for k, v in last_idx_dict.items():
self.num_identities_dict[k] = int(v) # total ids of each category
self.num_imgs_each_data = [len(x) for x in self.img_files.values()]
self.total_imgs = sum(self.num_imgs_each_data)
# cname2cid and cid2cname
cname2cid = {}
if self.label_list:
# if use label_list for multi source mix dataset,
# please make sure label_list in the first sub_dataset at least.
sub_dataset = self.image_lists[0].split('.')[0]
label_path = os.path.join(self.dataset_dir, sub_dataset,
self.label_list)
if not os.path.exists(label_path):
raise ValueError("label_list {} does not exists".format(
label_path))
with open(label_path, 'r') as fr:
label_id = 0
for line in fr.readlines():
cname2cid[line.strip()] = label_id
label_id += 1
else:
cname2cid = visdrone_mcmot_label()
cid2cname = dict([(v, k) for (k, v) in cname2cid.items()])
logger.info('MCMOT dataset summary: ')
logger.info(self.tid_num)
logger.info('Total images: {}'.format(self.total_imgs))
logger.info('Image start index: {}'.format(self.img_start_index))
logger.info('Total identities of each category: ')
self.num_identities_dict = sorted(
self.num_identities_dict.items(), key=lambda x: x[0])
total_IDs_all_cats = 0
for (k, v) in self.num_identities_dict:
logger.info('Category {} [{}] has {} IDs.'.format(k, cid2cname[k],
v))
total_IDs_all_cats += v
logger.info('Total identities of all categories: {}'.format(
total_IDs_all_cats))
logger.info('Identity start index of each category: ')
for k, v in self.tid_start_idx_of_cls_ids.items():
sorted_v = sorted(v.items(), key=lambda x: x[0])
for (cls_id, start_idx) in sorted_v:
logger.info('Start index of dataset {} category {:d} is {:d}'
.format(k, cls_id, start_idx))
records = []
for img_index in range(self.total_imgs):
for i, (k, v) in enumerate(self.img_start_index.items()):
if img_index >= v:
data_name = list(self.label_files.keys())[i]
start_index = v
img_file = self.img_files[data_name][img_index - start_index]
lbl_file = self.label_files[data_name][img_index - start_index]
if not os.path.exists(img_file):
logger.warning('Illegal image file: {}, and it will be ignored'.
format(img_file))
continue
if not os.path.isfile(lbl_file):
logger.warning('Illegal label file: {}, and it will be ignored'.
format(lbl_file))
continue
labels = np.loadtxt(lbl_file, dtype=np.float32).reshape(-1, 6)
# each row in labels (N, 6) is [gt_class, gt_identity, cx, cy, w, h]
cx, cy = labels[:, 2], labels[:, 3]
w, h = labels[:, 4], labels[:, 5]
gt_bbox = np.stack((cx, cy, w, h)).T.astype('float32')
gt_class = labels[:, 0:1].astype('int32')
gt_score = np.ones((len(labels), 1)).astype('float32')
gt_ide = labels[:, 1:2].astype('int32')
for i, _ in enumerate(gt_ide):
if gt_ide[i] > -1:
cls_id = int(gt_class[i])
start_idx = self.tid_start_idx_of_cls_ids[data_name][cls_id]
gt_ide[i] += start_idx
mot_rec = {
'im_file': img_file,
'im_id': img_index,
} if 'image' in self.data_fields else {}
gt_rec = {
'gt_class': gt_class,
'gt_score': gt_score,
'gt_bbox': gt_bbox,
'gt_ide': gt_ide,
}
for k, v in gt_rec.items():
if k in self.data_fields:
mot_rec[k] = v
records.append(mot_rec)
if self.sample_num > 0 and img_index >= self.sample_num:
break
assert len(records) > 0, 'not found any mot record in %s' % (
self.image_lists)
self.roidbs, self.cname2cid = records, cname2cid
@register @register
...@@ -382,3 +596,24 @@ def video2frames(video_path, outpath, frame_rate, **kargs): ...@@ -382,3 +596,24 @@ def video2frames(video_path, outpath, frame_rate, **kargs):
sys.stdout.flush() sys.stdout.flush()
return out_full_path return out_full_path
def mot_label():
labels_map = {'person': 0}
return labels_map
def visdrone_mcmot_label():
labels_map = {
'pedestrian': 0,
'people': 1,
'bicycle': 2,
'car': 3,
'van': 4,
'truck': 5,
'tricycle': 6,
'awning-tricycle': 7,
'bus': 8,
'motor': 9,
}
return labels_map
...@@ -556,6 +556,11 @@ class Gt2FairMOTTarget(Gt2TTFTarget): ...@@ -556,6 +556,11 @@ class Gt2FairMOTTarget(Gt2TTFTarget):
index_mask = np.zeros((self.max_objs, ), dtype=np.int32) index_mask = np.zeros((self.max_objs, ), dtype=np.int32)
reid = np.zeros((self.max_objs, ), dtype=np.int64) reid = np.zeros((self.max_objs, ), dtype=np.int64)
bbox_xys = np.zeros((self.max_objs, 4), dtype=np.float32) bbox_xys = np.zeros((self.max_objs, 4), dtype=np.float32)
if self.num_classes > 1:
# each category corresponds to a set of track ids
cls_tr_ids = np.zeros(
(self.num_classes, output_h, output_w), dtype=np.int64)
cls_id_map = np.full((output_h, output_w), -1, dtype=np.int64)
gt_bbox = sample['gt_bbox'] gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class'] gt_class = sample['gt_class']
...@@ -598,6 +603,10 @@ class Gt2FairMOTTarget(Gt2TTFTarget): ...@@ -598,6 +603,10 @@ class Gt2FairMOTTarget(Gt2TTFTarget):
index_mask[k] = 1 index_mask[k] = 1
reid[k] = ide reid[k] = ide
bbox_xys[k] = bbox_xy bbox_xys[k] = bbox_xy
if self.num_classes > 1:
cls_id_map[ct_int[1], ct_int[0]] = cls_id
cls_tr_ids[cls_id][ct_int[1]][ct_int[0]] = ide - 1
# track id start from 0
sample['heatmap'] = heatmap sample['heatmap'] = heatmap
sample['index'] = index sample['index'] = index
...@@ -605,6 +614,9 @@ class Gt2FairMOTTarget(Gt2TTFTarget): ...@@ -605,6 +614,9 @@ class Gt2FairMOTTarget(Gt2TTFTarget):
sample['size'] = bbox_size sample['size'] = bbox_size
sample['index_mask'] = index_mask sample['index_mask'] = index_mask
sample['reid'] = reid sample['reid'] = reid
if self.num_classes > 1:
sample['cls_id_map'] = cls_id_map
sample['cls_tr_ids'] = cls_tr_ids
sample['bbox_xys'] = bbox_xys sample['bbox_xys'] = bbox_xys
sample.pop('is_crowd', None) sample.pop('is_crowd', None)
sample.pop('difficult', None) sample.pop('difficult', None)
......
...@@ -21,14 +21,15 @@ import cv2 ...@@ -21,14 +21,15 @@ import cv2
import glob import glob
import paddle import paddle
import numpy as np import numpy as np
from collections import defaultdict
from ppdet.core.workspace import create from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
from ppdet.modeling.mot.utils import Timer, load_det_results from ppdet.modeling.mot.utils import MOTTimer, load_det_results, write_mot_results, save_vis_results
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric
from ppdet.metrics import MCMOTMetric
import ppdet.utils.stats as stats import ppdet.utils.stats as stats
from .callbacks import Callback, ComposeCallback from .callbacks import Callback, ComposeCallback
...@@ -74,6 +75,8 @@ class Tracker(object): ...@@ -74,6 +75,8 @@ class Tracker(object):
if self.cfg.metric == 'MOT': if self.cfg.metric == 'MOT':
self._metrics = [MOTMetric(), ] self._metrics = [MOTMetric(), ]
elif self.cfg.metric == 'MCMOT':
self._metrics = [MCMOTMetric(self.cfg.num_classes), ]
elif self.cfg.metric == 'KITTI': elif self.cfg.metric == 'KITTI':
self._metrics = [KITTIMOTMetric(), ] self._metrics = [KITTIMOTMetric(), ]
else: else:
...@@ -121,43 +124,49 @@ class Tracker(object): ...@@ -121,43 +124,49 @@ class Tracker(object):
tracker = self.model.tracker tracker = self.model.tracker
tracker.max_time_lost = int(frame_rate / 30.0 * tracker.track_buffer) tracker.max_time_lost = int(frame_rate / 30.0 * tracker.track_buffer)
timer = Timer() timer = MOTTimer()
results = []
frame_id = 0 frame_id = 0
self.status['mode'] = 'track' self.status['mode'] = 'track'
self.model.eval() self.model.eval()
results = defaultdict(list) # support single class and multi classes
for step_id, data in enumerate(dataloader): for step_id, data in enumerate(dataloader):
self.status['step_id'] = step_id self.status['step_id'] = step_id
if frame_id % 40 == 0: if frame_id % 40 == 0:
logger.info('Processing frame {} ({:.2f} fps)'.format( logger.info('Processing frame {} ({:.2f} fps)'.format(
frame_id, 1. / max(1e-5, timer.average_time))) frame_id, 1. / max(1e-5, timer.average_time)))
# forward # forward
timer.tic() timer.tic()
pred_dets, pred_embs = self.model(data) pred_dets, pred_embs = self.model(data)
online_targets = self.model.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_scores, online_ids = [], [], [] pred_dets, pred_embs = pred_dets.numpy(), pred_embs.numpy()
for t in online_targets: online_targets_dict = self.model.tracker.update(pred_dets,
tlwh = t.tlwh pred_embs)
tid = t.track_id online_tlwhs = defaultdict(list)
tscore = t.score online_scores = defaultdict(list)
if tscore < draw_threshold: continue online_ids = defaultdict(list)
if tlwh[2] * tlwh[3] <= tracker.min_box_area: continue for cls_id in range(self.cfg.num_classes):
if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[ online_targets = online_targets_dict[cls_id]
3] > tracker.vertical_ratio: for t in online_targets:
continue tlwh = t.tlwh
online_tlwhs.append(tlwh) tid = t.track_id
online_ids.append(tid) tscore = t.score
online_scores.append(tscore) if tlwh[2] * tlwh[3] <= tracker.min_box_area: continue
timer.toc() if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > tracker.vertical_ratio:
continue
online_tlwhs[cls_id].append(tlwh)
online_ids[cls_id].append(tid)
online_scores[cls_id].append(tscore)
# save results
results[cls_id].append(
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
online_ids[cls_id]))
# save results timer.toc()
results.append( save_vis_results(data, frame_id, online_ids, online_tlwhs,
(frame_id + 1, online_tlwhs, online_scores, online_ids)) online_scores, timer.average_time, show_image,
self.save_results(data, frame_id, online_ids, online_tlwhs, save_dir, self.cfg.num_classes)
online_scores, timer.average_time, show_image,
save_dir)
frame_id += 1 frame_id += 1
return results, frame_id, timer.average_time, timer.calls return results, frame_id, timer.average_time, timer.calls
...@@ -174,7 +183,7 @@ class Tracker(object): ...@@ -174,7 +183,7 @@ class Tracker(object):
if not os.path.exists(save_dir): os.makedirs(save_dir) if not os.path.exists(save_dir): os.makedirs(save_dir)
use_detector = False if not self.model.detector else True use_detector = False if not self.model.detector else True
timer = Timer() timer = MOTTimer()
results = [] results = []
frame_id = 0 frame_id = 0
self.status['mode'] = 'track' self.status['mode'] = 'track'
...@@ -284,9 +293,9 @@ class Tracker(object): ...@@ -284,9 +293,9 @@ class Tracker(object):
# save results # save results
results.append( results.append(
(frame_id + 1, online_tlwhs, online_scores, online_ids)) (frame_id + 1, online_tlwhs, online_scores, online_ids))
self.save_results(data, frame_id, online_ids, online_tlwhs, save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image, online_scores, timer.average_time, show_image,
save_dir) save_dir, self.cfg.num_classes)
frame_id += 1 frame_id += 1
return results, frame_id, timer.average_time, timer.calls return results, frame_id, timer.average_time, timer.calls
...@@ -305,37 +314,39 @@ class Tracker(object): ...@@ -305,37 +314,39 @@ class Tracker(object):
if not os.path.exists(output_dir): os.makedirs(output_dir) if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results') result_root = os.path.join(output_dir, 'mot_results')
if not os.path.exists(result_root): os.makedirs(result_root) if not os.path.exists(result_root): os.makedirs(result_root)
assert data_type in ['mot', 'kitti'], \ assert data_type in ['mot', 'mcmot', 'kitti'], \
"data_type should be 'mot' or 'kitti'" "data_type should be 'mot', 'mcmot' or 'kitti'"
assert model_type in ['JDE', 'DeepSORT', 'FairMOT'], \ assert model_type in ['JDE', 'DeepSORT', 'FairMOT'], \
"model_type should be 'JDE', 'DeepSORT' or 'FairMOT'" "model_type should be 'JDE', 'DeepSORT' or 'FairMOT'"
# run tracking # run tracking
n_frame = 0 n_frame = 0
timer_avgs, timer_calls = [], [] timer_avgs, timer_calls = [], []
for seq in seqs: for seq in seqs:
if not os.path.isdir(os.path.join(data_root, seq)): infer_dir = os.path.join(data_root, seq)
if not os.path.exists(infer_dir) or not os.path.isdir(infer_dir):
logger.warning("Seq {} error, {} has no images.".format(
seq, infer_dir))
continue continue
infer_dir = os.path.join(data_root, seq, 'img1') if os.path.exists(os.path.join(infer_dir, 'img1')):
infer_dir = os.path.join(infer_dir, 'img1')
frame_rate = 30
seqinfo = os.path.join(data_root, seq, 'seqinfo.ini') seqinfo = os.path.join(data_root, seq, 'seqinfo.ini')
if not os.path.exists(seqinfo) or not os.path.exists( if os.path.exists(seqinfo):
infer_dir) or not os.path.isdir(infer_dir): meta_info = open(seqinfo).read()
continue frame_rate = int(meta_info[meta_info.find('frameRate') + 10:
meta_info.find('\nseqLength')])
save_dir = os.path.join(output_dir, 'mot_outputs', save_dir = os.path.join(output_dir, 'mot_outputs',
seq) if save_images or save_videos else None seq) if save_images or save_videos else None
logger.info('start seq: {}'.format(seq)) logger.info('start seq: {}'.format(seq))
images = self.get_infer_images(infer_dir) self.dataset.set_images(self.get_infer_images(infer_dir))
self.dataset.set_images(images)
dataloader = create('EvalMOTReader')(self.dataset, 0) dataloader = create('EvalMOTReader')(self.dataset, 0)
result_filename = os.path.join(result_root, '{}.txt'.format(seq)) result_filename = os.path.join(result_root, '{}.txt'.format(seq))
meta_info = open(seqinfo).read()
frame_rate = int(meta_info[meta_info.find('frameRate') + 10:
meta_info.find('\nseqLength')])
with paddle.no_grad(): with paddle.no_grad():
if model_type in ['JDE', 'FairMOT']: if model_type in ['JDE', 'FairMOT']:
results, nf, ta, tc = self._eval_seq_jde( results, nf, ta, tc = self._eval_seq_jde(
...@@ -355,7 +366,8 @@ class Tracker(object): ...@@ -355,7 +366,8 @@ class Tracker(object):
else: else:
raise ValueError(model_type) raise ValueError(model_type)
self.write_mot_results(result_filename, results, data_type) write_mot_results(result_filename, results, data_type,
self.cfg.num_classes)
n_frame += nf n_frame += nf
timer_avgs.append(ta) timer_avgs.append(ta)
timer_calls.append(tc) timer_calls.append(tc)
...@@ -427,8 +439,8 @@ class Tracker(object): ...@@ -427,8 +439,8 @@ class Tracker(object):
if not os.path.exists(output_dir): os.makedirs(output_dir) if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results') result_root = os.path.join(output_dir, 'mot_results')
if not os.path.exists(result_root): os.makedirs(result_root) if not os.path.exists(result_root): os.makedirs(result_root)
assert data_type in ['mot', 'kitti'], \ assert data_type in ['mot', 'mcmot', 'kitti'], \
"data_type should be 'mot' or 'kitti'" "data_type should be 'mot', 'mcmot' or 'kitti'"
assert model_type in ['JDE', 'DeepSORT', 'FairMOT'], \ assert model_type in ['JDE', 'DeepSORT', 'FairMOT'], \
"model_type should be 'JDE', 'DeepSORT' or 'FairMOT'" "model_type should be 'JDE', 'DeepSORT' or 'FairMOT'"
...@@ -478,7 +490,8 @@ class Tracker(object): ...@@ -478,7 +490,8 @@ class Tracker(object):
else: else:
raise ValueError(model_type) raise ValueError(model_type)
self.write_mot_results(result_filename, results, data_type) write_mot_results(result_filename, results, data_type,
self.cfg.num_classes)
if save_videos: if save_videos:
output_video_path = os.path.join(save_dir, '..', output_video_path = os.path.join(save_dir, '..',
...@@ -487,52 +500,3 @@ class Tracker(object): ...@@ -487,52 +500,3 @@ class Tracker(object):
save_dir, output_video_path) save_dir, output_video_path)
os.system(cmd_str) os.system(cmd_str)
logger.info('Save video in {}'.format(output_video_path)) logger.info('Save video in {}'.format(output_video_path))
def write_mot_results(self, filename, results, data_type='mot'):
if data_type in ['mot', 'mcmot', 'lab']:
save_format = '{frame},{id},{x1},{y1},{w},{h},{score},-1,-1,-1\n'
elif data_type == 'kitti':
save_format = '{frame} {id} car 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
else:
raise ValueError(data_type)
with open(filename, 'w') as f:
for frame_id, tlwhs, tscores, track_ids in results:
if data_type == 'kitti':
frame_id -= 1
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0:
continue
x1, y1, w, h = tlwh
x2, y2 = x1 + w, y1 + h
line = save_format.format(
frame=frame_id,
id=track_id,
x1=x1,
y1=y1,
x2=x2,
y2=y2,
w=w,
h=h,
score=score)
f.write(line)
logger.info('MOT results save in {}'.format(filename))
def save_results(self, data, frame_id, online_ids, online_tlwhs,
online_scores, average_time, show_image, save_dir):
if show_image or save_dir is not None:
assert 'ori_image' in data
img0 = data['ori_image'].numpy()[0]
online_im = mot_vis.plot_tracking(
img0,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=1. / average_time)
if show_image:
cv2.imshow('online_im', online_im)
if save_dir is not None:
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
online_im)
...@@ -41,7 +41,7 @@ from ppdet.data.source.category import get_categories ...@@ -41,7 +41,7 @@ from ppdet.data.source.category import get_categories
import ppdet.utils.stats as stats import ppdet.utils.stats as stats
from ppdet.utils import profiler from ppdet.utils import profiler
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter,SniperProposalsGenerator from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
from .export_utils import _dump_infer_config, _prune_input_spec from .export_utils import _dump_infer_config, _prune_input_spec
from ppdet.utils.logger import setup_logger from ppdet.utils.logger import setup_logger
...@@ -77,11 +77,12 @@ class Trainer(object): ...@@ -77,11 +77,12 @@ class Trainer(object):
if cfg.architecture == 'JDE' and self.mode == 'train': if cfg.architecture == 'JDE' and self.mode == 'train':
cfg['JDEEmbeddingHead'][ cfg['JDEEmbeddingHead'][
'num_identifiers'] = self.dataset.total_identities 'num_identities'] = self.dataset.num_identities_dict[0]
# JDE only support single class MOT now.
if cfg.architecture == 'FairMOT' and self.mode == 'train': if cfg.architecture == 'FairMOT' and self.mode == 'train':
cfg['FairMOTEmbeddingHead'][ cfg['FairMOTEmbeddingHead']['num_identities_dict'] = self.dataset.num_identities_dict
'num_identifiers'] = self.dataset.total_identities # FairMOT support single class and multi-class MOT now.
# build model # build model
if 'model' not in self.cfg: if 'model' not in self.cfg:
...@@ -192,7 +193,7 @@ class Trainer(object): ...@@ -192,7 +193,7 @@ class Trainer(object):
IouType=IouType, IouType=IouType,
save_prediction_only=save_prediction_only) save_prediction_only=save_prediction_only)
] ]
elif self.cfg.metric == "SNIPERCOCO": # sniper elif self.cfg.metric == "SNIPERCOCO": # sniper
self._metrics = [ self._metrics = [
SNIPERCOCOMetric( SNIPERCOCOMetric(
anno_file=anno_file, anno_file=anno_file,
...@@ -202,8 +203,7 @@ class Trainer(object): ...@@ -202,8 +203,7 @@ class Trainer(object):
output_eval=output_eval, output_eval=output_eval,
bias=bias, bias=bias,
IouType=IouType, IouType=IouType,
save_prediction_only=save_prediction_only save_prediction_only=save_prediction_only)
)
] ]
elif self.cfg.metric == 'RBOX': elif self.cfg.metric == 'RBOX':
# TODO: bias should be unified # TODO: bias should be unified
...@@ -516,7 +516,8 @@ class Trainer(object): ...@@ -516,7 +516,8 @@ class Trainer(object):
results.append(outs) results.append(outs)
# sniper # sniper
if type(self.dataset) == SniperCOCODataSet: if type(self.dataset) == SniperCOCODataSet:
results = self.dataset.anno_cropper.aggregate_chips_detections(results) results = self.dataset.anno_cropper.aggregate_chips_detections(
results)
for outs in results: for outs in results:
batch_res = get_infer_results(outs, clsid2catid) batch_res = get_infer_results(outs, clsid2catid)
......
...@@ -22,4 +22,8 @@ __all__ = metrics.__all__ + keypoint_metrics.__all__ ...@@ -22,4 +22,8 @@ __all__ = metrics.__all__ + keypoint_metrics.__all__
from . import mot_metrics from . import mot_metrics
from .mot_metrics import * from .mot_metrics import *
__all__ = __all__ + mot_metrics.__all__ __all__ = metrics.__all__ + mot_metrics.__all__
from . import mcmot_metrics
from .mcmot_metrics import *
__all__ = metrics.__all__ + mcmot_metrics.__all__
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import sys
import math
from collections import defaultdict
from motmetrics.math_util import quiet_divide
import numpy as np
import pandas as pd
import paddle
import paddle.nn.functional as F
from .metrics import Metric
import motmetrics as mm
import openpyxl
metrics = mm.metrics.motchallenge_metrics
mh = mm.metrics.create()
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['MCMOTEvaluator', 'MCMOTMetric']
METRICS_LIST = [
'num_frames', 'num_matches', 'num_switches', 'num_transfer', 'num_ascend',
'num_migrate', 'num_false_positives', 'num_misses', 'num_detections',
'num_objects', 'num_predictions', 'num_unique_objects', 'mostly_tracked',
'partially_tracked', 'mostly_lost', 'num_fragmentations', 'motp', 'mota',
'precision', 'recall', 'idfp', 'idfn', 'idtp', 'idp', 'idr', 'idf1'
]
NAME_MAP = {
'num_frames': 'num_frames',
'num_matches': 'num_matches',
'num_switches': 'IDs',
'num_transfer': 'IDt',
'num_ascend': 'IDa',
'num_migrate': 'IDm',
'num_false_positives': 'FP',
'num_misses': 'FN',
'num_detections': 'num_detections',
'num_objects': 'num_objects',
'num_predictions': 'num_predictions',
'num_unique_objects': 'GT',
'mostly_tracked': 'MT',
'partially_tracked': 'partially_tracked',
'mostly_lost': 'ML',
'num_fragmentations': 'FM',
'motp': 'MOTP',
'mota': 'MOTA',
'precision': 'Prcn',
'recall': 'Rcll',
'idfp': 'idfp',
'idfn': 'idfn',
'idtp': 'idtp',
'idp': 'IDP',
'idr': 'IDR',
'idf1': 'IDF1'
}
def parse_accs_metrics(seq_acc, index_name, verbose=False):
"""
Parse the evaluation indicators of multiple MOTAccumulator
"""
mh = mm.metrics.create()
summary = MCMOTEvaluator.get_summary(seq_acc, index_name, METRICS_LIST)
summary.loc['OVERALL', 'motp'] = (summary['motp'] * summary['num_detections']).sum() / \
summary.loc['OVERALL', 'num_detections']
if verbose:
strsummary = mm.io.render_summary(
summary, formatters=mh.formatters, namemap=NAME_MAP)
print(strsummary)
return summary
def seqs_overall_metrics(summary_df, verbose=False):
"""
Calculate overall metrics for multiple sequences
"""
add_col = [
'num_frames', 'num_matches', 'num_switches', 'num_transfer',
'num_ascend', 'num_migrate', 'num_false_positives', 'num_misses',
'num_detections', 'num_objects', 'num_predictions',
'num_unique_objects', 'mostly_tracked', 'partially_tracked',
'mostly_lost', 'num_fragmentations', 'idfp', 'idfn', 'idtp'
]
calc_col = ['motp', 'mota', 'precision', 'recall', 'idp', 'idr', 'idf1']
calc_df = summary_df.copy()
overall_dic = {}
for col in add_col:
overall_dic[col] = calc_df[col].sum()
for col in calc_col:
overall_dic[col] = getattr(MCMOTMetricOverall, col + '_overall')(
calc_df, overall_dic)
overall_df = pd.DataFrame(overall_dic, index=['overall_calc'])
calc_df = pd.concat([calc_df, overall_df])
if verbose:
mh = mm.metrics.create()
str_calc_df = mm.io.render_summary(
calc_df, formatters=mh.formatters, namemap=NAME_MAP)
print(str_calc_df)
return calc_df
class MCMOTMetricOverall(object):
def motp_overall(summary_df, overall_dic):
motp = quiet_divide((summary_df['motp'] *
summary_df['num_detections']).sum(),
overall_dic['num_detections'])
return motp
def mota_overall(summary_df, overall_dic):
del summary_df
mota = 1. - quiet_divide(
(overall_dic['num_misses'] + overall_dic['num_switches'] +
overall_dic['num_false_positives']), overall_dic['num_objects'])
return mota
def precision_overall(summary_df, overall_dic):
del summary_df
precision = quiet_divide(overall_dic['num_detections'], (
overall_dic['num_false_positives'] + overall_dic['num_detections']))
return precision
def recall_overall(summary_df, overall_dic):
del summary_df
recall = quiet_divide(overall_dic['num_detections'],
overall_dic['num_objects'])
return recall
def idp_overall(summary_df, overall_dic):
del summary_df
idp = quiet_divide(overall_dic['idtp'],
(overall_dic['idtp'] + overall_dic['idfp']))
return idp
def idr_overall(summary_df, overall_dic):
del summary_df
idr = quiet_divide(overall_dic['idtp'],
(overall_dic['idtp'] + overall_dic['idfn']))
return idr
def idf1_overall(summary_df, overall_dic):
del summary_df
idf1 = quiet_divide(2. * overall_dic['idtp'], (
overall_dic['num_objects'] + overall_dic['num_predictions']))
return idf1
def read_mcmot_results_union(filename, is_gt, is_ignore):
results_dict = dict()
if os.path.isfile(filename):
all_result = np.loadtxt(filename, delimiter=',')
if all_result.shape[0] == 0 or all_result.shape[1] < 7:
return results_dict
if is_ignore:
return results_dict
if is_gt:
# only for test use
all_result = all_result[all_result[:, 7] != 0]
all_result[:, 7] = all_result[:, 7] - 1
if all_result.shape[0] == 0:
return results_dict
class_unique = np.unique(all_result[:, 7])
last_max_id = 0
result_cls_list = []
for cls in class_unique:
result_cls_split = all_result[all_result[:, 7] == cls]
result_cls_split[:, 1] = result_cls_split[:, 1] + last_max_id
# make sure track id different between every category
last_max_id = max(np.unique(result_cls_split[:, 1])) + 1
result_cls_list.append(result_cls_split)
results_con = np.concatenate(result_cls_list)
for line in range(len(results_con)):
linelist = results_con[line]
fid = int(linelist[0])
if fid < 1:
continue
results_dict.setdefault(fid, list())
if is_gt:
score = 1
else:
score = float(linelist[6])
tlwh = tuple(map(float, linelist[2:6]))
target_id = int(linelist[1])
cls = int(linelist[7])
results_dict[fid].append((tlwh, target_id, cls, score))
return results_dict
def read_mcmot_results(filename, is_gt, is_ignore):
results_dict = dict()
if os.path.isfile(filename):
with open(filename, 'r') as f:
for line in f.readlines():
linelist = line.strip().split(',')
if len(linelist) < 7:
continue
fid = int(linelist[0])
if fid < 1:
continue
cid = int(linelist[7])
if is_gt:
score = 1
# only for test use
cid -= 1
else:
score = float(linelist[6])
cls_result_dict = results_dict.setdefault(cid, dict())
cls_result_dict.setdefault(fid, list())
tlwh = tuple(map(float, linelist[2:6]))
target_id = int(linelist[1])
cls_result_dict[fid].append((tlwh, target_id, score))
return results_dict
def read_results(filename,
data_type,
is_gt=False,
is_ignore=False,
multi_class=False,
union=False):
if data_type in ['mcmot', 'lab']:
if multi_class:
if union:
# The results are evaluated by union all the categories.
# Track IDs between different categories cannot be duplicate.
read_fun = read_mcmot_results_union
else:
# The results are evaluated separately by category.
read_fun = read_mcmot_results
else:
raise ValueError('multi_class: {}, MCMOT should have cls_id.'.
format(multi_class))
else:
raise ValueError('Unknown data type: {}'.format(data_type))
return read_fun(filename, is_gt, is_ignore)
def unzip_objs(objs):
if len(objs) > 0:
tlwhs, ids, scores = zip(*objs)
else:
tlwhs, ids, scores = [], [], []
tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
return tlwhs, ids, scores
def unzip_objs_cls(objs):
if len(objs) > 0:
tlwhs, ids, cls, scores = zip(*objs)
else:
tlwhs, ids, cls, scores = [], [], [], []
tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
ids = np.array(ids)
cls = np.array(cls)
scores = np.array(scores)
return tlwhs, ids, cls, scores
class MCMOTEvaluator(object):
def __init__(self, data_root, seq_name, data_type, num_classes):
self.data_root = data_root
self.seq_name = seq_name
self.data_type = data_type
self.num_classes = num_classes
self.load_annotations()
self.reset_accumulator()
self.class_accs = []
def load_annotations(self):
assert self.data_type == 'mcmot'
self.gt_filename = os.path.join(self.data_root, '../', '../',
'sequences',
'{}.txt'.format(self.seq_name))
def reset_accumulator(self):
import motmetrics as mm
mm.lap.default_solver = 'lap'
self.acc = mm.MOTAccumulator(auto_id=True)
def eval_frame_dict(self, trk_objs, gt_objs, rtn_events=False, union=False):
import motmetrics as mm
mm.lap.default_solver = 'lap'
if union:
trk_tlwhs, trk_ids, trk_cls = unzip_objs_cls(trk_objs)[:3]
gt_tlwhs, gt_ids, gt_cls = unzip_objs_cls(gt_objs)[:3]
# get distance matrix
iou_distance = mm.distances.iou_matrix(
gt_tlwhs, trk_tlwhs, max_iou=0.5)
# Set the distance between objects of different categories to nan
gt_cls_len = len(gt_cls)
trk_cls_len = len(trk_cls)
# When the number of GT or Trk is 0, iou_distance dimension is (0,0)
if gt_cls_len != 0 and trk_cls_len != 0:
gt_cls = gt_cls.reshape(gt_cls_len, 1)
gt_cls = np.repeat(gt_cls, trk_cls_len, axis=1)
trk_cls = trk_cls.reshape(1, trk_cls_len)
trk_cls = np.repeat(trk_cls, gt_cls_len, axis=0)
iou_distance = np.where(gt_cls == trk_cls, iou_distance, np.nan)
else:
trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]
# get distance matrix
iou_distance = mm.distances.iou_matrix(
gt_tlwhs, trk_tlwhs, max_iou=0.5)
self.acc.update(gt_ids, trk_ids, iou_distance)
if rtn_events and iou_distance.size > 0 and hasattr(self.acc,
'mot_events'):
events = self.acc.mot_events # only supported by https://github.com/longcw/py-motmetrics
else:
events = None
return events
def eval_file(self, result_filename):
# evaluation of each category
gt_frame_dict = read_results(
self.gt_filename,
self.data_type,
is_gt=True,
multi_class=True,
union=False)
result_frame_dict = read_results(
result_filename,
self.data_type,
is_gt=False,
multi_class=True,
union=False)
for cid in range(self.num_classes):
self.reset_accumulator()
cls_result_frame_dict = result_frame_dict.setdefault(cid, dict())
cls_gt_frame_dict = gt_frame_dict.setdefault(cid, dict())
# only labeled frames will be evaluated
frames = sorted(list(set(cls_gt_frame_dict.keys())))
for frame_id in frames:
trk_objs = cls_result_frame_dict.get(frame_id, [])
gt_objs = cls_gt_frame_dict.get(frame_id, [])
self.eval_frame_dict(trk_objs, gt_objs, rtn_events=False)
self.class_accs.append(self.acc)
return self.class_accs
@staticmethod
def get_summary(accs,
names,
metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1',
'precision', 'recall')):
import motmetrics as mm
mm.lap.default_solver = 'lap'
names = copy.deepcopy(names)
if metrics is None:
metrics = mm.metrics.motchallenge_metrics
metrics = copy.deepcopy(metrics)
mh = mm.metrics.create()
summary = mh.compute_many(
accs, metrics=metrics, names=names, generate_overall=True)
return summary
@staticmethod
def save_summary(summary, filename):
import pandas as pd
writer = pd.ExcelWriter(filename)
summary.to_excel(writer)
writer.save()
class MCMOTMetric(Metric):
def __init__(self, num_classes, save_summary=False):
self.num_classes = num_classes
self.save_summary = save_summary
self.MCMOTEvaluator = MCMOTEvaluator
self.result_root = None
self.reset()
self.seqs_overall = defaultdict(list)
def reset(self):
self.accs = []
self.seqs = []
def update(self, data_root, seq, data_type, result_root, result_filename):
evaluator = self.MCMOTEvaluator(data_root, seq, data_type,
self.num_classes)
seq_acc = evaluator.eval_file(result_filename)
self.accs.append(seq_acc)
self.seqs.append(seq)
self.result_root = result_root
cls_index_name = [
'{}_{}'.format(seq, i) for i in range(self.num_classes)
]
summary = parse_accs_metrics(seq_acc, cls_index_name)
summary.rename(
index={'OVERALL': '{}_OVERALL'.format(seq)}, inplace=True)
for row in range(len(summary)):
self.seqs_overall[row].append(summary.iloc[row:row + 1])
def accumulate(self):
self.cls_summary_list = []
for row in range(self.num_classes):
seqs_cls_df = pd.concat(self.seqs_overall[row])
seqs_cls_summary = seqs_overall_metrics(seqs_cls_df)
cls_summary_overall = seqs_cls_summary.iloc[-1:].copy()
cls_summary_overall.rename(
index={'overall_calc': 'overall_calc_{}'.format(row)},
inplace=True)
self.cls_summary_list.append(cls_summary_overall)
def log(self):
seqs_summary = seqs_overall_metrics(
pd.concat(self.seqs_overall[self.num_classes]), verbose=True)
class_summary = seqs_overall_metrics(
pd.concat(self.cls_summary_list), verbose=True)
def get_results(self):
return 1
...@@ -79,7 +79,7 @@ class CenterNet(BaseArch): ...@@ -79,7 +79,7 @@ class CenterNet(BaseArch):
def get_pred(self): def get_pred(self):
head_out = self._forward() head_out = self._forward()
if self.for_mot: if self.for_mot:
bbox, bbox_inds = self.post_process( bbox, bbox_inds, topk_clses = self.post_process(
head_out['heatmap'], head_out['heatmap'],
head_out['size'], head_out['size'],
head_out['offset'], head_out['offset'],
...@@ -88,10 +88,11 @@ class CenterNet(BaseArch): ...@@ -88,10 +88,11 @@ class CenterNet(BaseArch):
output = { output = {
"bbox": bbox, "bbox": bbox,
"bbox_inds": bbox_inds, "bbox_inds": bbox_inds,
"topk_clses": topk_clses,
"neck_feat": head_out['neck_feat'] "neck_feat": head_out['neck_feat']
} }
else: else:
bbox, bbox_num = self.post_process( bbox, bbox_num, _ = self.post_process(
head_out['heatmap'], head_out['heatmap'],
head_out['size'], head_out['size'],
head_out['offset'], head_out['offset'],
......
...@@ -86,13 +86,9 @@ class FairMOT(BaseArch): ...@@ -86,13 +86,9 @@ class FairMOT(BaseArch):
loss.update({'reid_loss': reid_loss}) loss.update({'reid_loss': reid_loss})
return loss return loss
else: else:
embedding = self.reid(neck_feat, self.inputs) pred_dets, pred_embs = self.reid(
bbox_inds = det_outs['bbox_inds'] neck_feat, self.inputs, det_outs['bbox'], det_outs['bbox_inds'],
embedding = paddle.transpose(embedding, [0, 2, 3, 1]) det_outs['topk_clses'])
embedding = paddle.reshape(embedding,
[-1, paddle.shape(embedding)[-1]])
pred_embs = paddle.gather(embedding, bbox_inds)
pred_dets = det_outs['bbox']
return pred_dets, pred_embs return pred_dets, pred_embs
def get_pred(self): def get_pred(self):
......
...@@ -59,7 +59,7 @@ class CenterNetHead(nn.Layer): ...@@ -59,7 +59,7 @@ class CenterNetHead(nn.Layer):
""" """
Args: Args:
in_channels (int): the channel number of input to CenterNetHead. in_channels (int): the channel number of input to CenterNetHead.
num_classes (int): the number of classes, 80 by default. num_classes (int): the number of classes, 80 (COCO dataset) by default.
head_planes (int): the channel number in all head, 256 by default. head_planes (int): the channel number in all head, 256 by default.
heatmap_weight (float): the weight of heatmap loss, 1 by default. heatmap_weight (float): the weight of heatmap loss, 1 by default.
regress_ltrb (bool): whether to regress left/top/right/bottom or regress_ltrb (bool): whether to regress left/top/right/bottom or
...@@ -83,6 +83,7 @@ class CenterNetHead(nn.Layer): ...@@ -83,6 +83,7 @@ class CenterNetHead(nn.Layer):
offset_weight=1, offset_weight=1,
iou_weight=0): iou_weight=0):
super(CenterNetHead, self).__init__() super(CenterNetHead, self).__init__()
self.regress_ltrb = regress_ltrb
self.weights = { self.weights = {
'heatmap': heatmap_weight, 'heatmap': heatmap_weight,
'size': size_weight, 'size': size_weight,
...@@ -196,7 +197,14 @@ class CenterNetHead(nn.Layer): ...@@ -196,7 +197,14 @@ class CenterNetHead(nn.Layer):
pos_num = size_mask.sum() pos_num = size_mask.sum()
size_mask.stop_gradient = True size_mask.stop_gradient = True
if self.size_loss == 'L1': if self.size_loss == 'L1':
size_target = inputs['size'] if self.regress_ltrb:
size_target = inputs['size']
# shape: [bs, max_per_img, 4]
else:
size_target = inputs['size'][:, :, 0:2] + inputs['size'][:, :,
2:]
# shape: [bs, max_per_img, 2]
size_target.stop_gradient = True size_target.stop_gradient = True
size_loss = F.l1_loss( size_loss = F.l1_loss(
pos_size * size_mask, size_target * size_mask, reduction='sum') pos_size * size_mask, size_target * size_mask, reduction='sum')
......
...@@ -11,11 +11,9 @@ ...@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import numpy as np import numpy as np
from collections import defaultdict
from collections import deque, OrderedDict from collections import deque, OrderedDict
from ..matching import jde_matching as matching from ..matching import jde_matching as matching
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
...@@ -40,7 +38,7 @@ class TrackState(object): ...@@ -40,7 +38,7 @@ class TrackState(object):
@register @register
@serializable @serializable
class BaseTrack(object): class BaseTrack(object):
_count = 0 _count_dict = defaultdict(int) # support single class and multi classes
track_id = 0 track_id = 0
is_activated = False is_activated = False
...@@ -62,9 +60,23 @@ class BaseTrack(object): ...@@ -62,9 +60,23 @@ class BaseTrack(object):
return self.frame_id return self.frame_id
@staticmethod @staticmethod
def next_id(): def next_id(cls_id):
BaseTrack._count += 1 BaseTrack._count_dict[cls_id] += 1
return BaseTrack._count return BaseTrack._count_dict[cls_id]
# @even: reset track id
@staticmethod
def init_count(num_classes):
"""
Initiate _count for all object classes
:param num_classes:
"""
for cls_id in range(num_classes):
BaseTrack._count_dict[cls_id] = 0
@staticmethod
def reset_track_count(cls_id):
BaseTrack._count_dict[cls_id] = 0
def activate(self, *args): def activate(self, *args):
raise NotImplementedError raise NotImplementedError
...@@ -85,7 +97,15 @@ class BaseTrack(object): ...@@ -85,7 +97,15 @@ class BaseTrack(object):
@register @register
@serializable @serializable
class STrack(BaseTrack): class STrack(BaseTrack):
def __init__(self, tlwh, score, temp_feat, buffer_size=30): def __init__(self,
tlwh,
score,
temp_feat,
num_classes,
cls_id,
buff_size=30):
# object class id
self.cls_id = cls_id
# wait activate # wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float) self._tlwh = np.asarray(tlwh, dtype=np.float)
self.kalman_filter = None self.kalman_filter = None
...@@ -93,20 +113,21 @@ class STrack(BaseTrack): ...@@ -93,20 +113,21 @@ class STrack(BaseTrack):
self.is_activated = False self.is_activated = False
self.score = score self.score = score
self.tracklet_len = 0 self.track_len = 0
self.smooth_feat = None self.smooth_feat = None
self.update_features(temp_feat) self.update_features(temp_feat)
self.features = deque([], maxlen=buffer_size) self.features = deque([], maxlen=buff_size)
self.alpha = 0.9 self.alpha = 0.9
def update_features(self, feat): def update_features(self, feat):
# L2 normalizing
feat /= np.linalg.norm(feat) feat /= np.linalg.norm(feat)
self.curr_feat = feat self.curr_feat = feat
if self.smooth_feat is None: if self.smooth_feat is None:
self.smooth_feat = feat self.smooth_feat = feat
else: else:
self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha self.smooth_feat = self.alpha * self.smooth_feat + (1.0 - self.alpha
) * feat ) * feat
self.features.append(feat) self.features.append(feat)
self.smooth_feat /= np.linalg.norm(self.smooth_feat) self.smooth_feat /= np.linalg.norm(self.smooth_feat)
...@@ -119,54 +140,60 @@ class STrack(BaseTrack): ...@@ -119,54 +140,60 @@ class STrack(BaseTrack):
self.covariance) self.covariance)
@staticmethod @staticmethod
def multi_predict(stracks, kalman_filter): def multi_predict(tracks, kalman_filter):
if len(stracks) > 0: if len(tracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks]) multi_mean = np.asarray([track.mean.copy() for track in tracks])
multi_covariance = np.asarray([st.covariance for st in stracks]) multi_covariance = np.asarray(
for i, st in enumerate(stracks): [track.covariance for track in tracks])
for i, st in enumerate(tracks):
if st.state != TrackState.Tracked: if st.state != TrackState.Tracked:
multi_mean[i][7] = 0 multi_mean[i][7] = 0
multi_mean, multi_covariance = kalman_filter.multi_predict( multi_mean, multi_covariance = kalman_filter.multi_predict(
multi_mean, multi_covariance) multi_mean, multi_covariance)
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
stracks[i].mean = mean tracks[i].mean = mean
stracks[i].covariance = cov tracks[i].covariance = cov
def reset_track_id(self):
self.reset_track_count(self.cls_id)
def activate(self, kalman_filter, frame_id): def activate(self, kalman_filter, frame_id):
"""Start a new tracklet""" """Start a new track"""
self.kalman_filter = kalman_filter self.kalman_filter = kalman_filter
self.track_id = self.next_id() # update track id for the object class
self.track_id = self.next_id(self.cls_id)
self.mean, self.covariance = self.kalman_filter.initiate( self.mean, self.covariance = self.kalman_filter.initiate(
self.tlwh_to_xyah(self._tlwh)) self.tlwh_to_xyah(self._tlwh))
self.tracklet_len = 0 self.track_len = 0
self.state = TrackState.Tracked self.state = TrackState.Tracked # set flag 'tracked'
if frame_id == 1:
if frame_id == 1: # to record the first frame's detection result
self.is_activated = True self.is_activated = True
self.frame_id = frame_id self.frame_id = frame_id
self.start_frame = frame_id self.start_frame = frame_id
def re_activate(self, new_track, frame_id, new_id=False): def re_activate(self, new_track, frame_id, new_id=False):
self.mean, self.covariance = self.kalman_filter.update( self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)) self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh))
self.update_features(new_track.curr_feat) self.update_features(new_track.curr_feat)
self.tracklet_len = 0 self.track_len = 0
self.state = TrackState.Tracked self.state = TrackState.Tracked
self.is_activated = True self.is_activated = True
self.frame_id = frame_id self.frame_id = frame_id
if new_id: if new_id: # update track id for the object class
self.track_id = self.next_id() self.track_id = self.next_id(self.cls_id)
def update(self, new_track, frame_id, update_feature=True): def update(self, new_track, frame_id, update_feature=True):
self.frame_id = frame_id self.frame_id = frame_id
self.tracklet_len += 1 self.track_len += 1
new_tlwh = new_track.tlwh new_tlwh = new_track.tlwh
self.mean, self.covariance = self.kalman_filter.update( self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh)) self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
self.state = TrackState.Tracked self.state = TrackState.Tracked # set flag 'tracked'
self.is_activated = True self.is_activated = True # set flag 'activated'
self.score = new_track.score self.score = new_track.score
if update_feature: if update_feature:
...@@ -174,12 +201,12 @@ class STrack(BaseTrack): ...@@ -174,12 +201,12 @@ class STrack(BaseTrack):
@property @property
def tlwh(self): def tlwh(self):
""" """Get current position in bounding box format `(top left x, top left y,
Get current position in bounding box format `(top left x, top left y, width, height)`.
width, height)`.
""" """
if self.mean is None: if self.mean is None:
return self._tlwh.copy() return self._tlwh.copy()
ret = self.mean[:4].copy() ret = self.mean[:4].copy()
ret[2] *= ret[3] ret[2] *= ret[3]
ret[:2] -= ret[2:] / 2 ret[:2] -= ret[2:] / 2
...@@ -187,8 +214,7 @@ class STrack(BaseTrack): ...@@ -187,8 +214,7 @@ class STrack(BaseTrack):
@property @property
def tlbr(self): def tlbr(self):
""" """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`. `(top left, bottom right)`.
""" """
ret = self.tlwh.copy() ret = self.tlwh.copy()
...@@ -197,8 +223,7 @@ class STrack(BaseTrack): ...@@ -197,8 +223,7 @@ class STrack(BaseTrack):
@staticmethod @staticmethod
def tlwh_to_xyah(tlwh): def tlwh_to_xyah(tlwh):
""" """Convert bounding box to format `(center x, center y, aspect ratio,
Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`. height)`, where the aspect ratio is `width / height`.
""" """
ret = np.asarray(tlwh).copy() ret = np.asarray(tlwh).copy()
...@@ -222,8 +247,8 @@ class STrack(BaseTrack): ...@@ -222,8 +247,8 @@ class STrack(BaseTrack):
return ret return ret
def __repr__(self): def __repr__(self):
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, return 'OT_({}-{})_({}-{})'.format(self.cls_id, self.track_id,
self.end_frame) self.start_frame, self.end_frame)
def joint_stracks(tlista, tlistb): def joint_stracks(tlista, tlistb):
......
...@@ -17,10 +17,13 @@ import cv2 ...@@ -17,10 +17,13 @@ import cv2
import time import time
import paddle import paddle
import numpy as np import numpy as np
from .visualization import plot_tracking_dict
__all__ = [ __all__ = [
'Timer', 'MOTTimer',
'Detection', 'Detection',
'write_mot_results',
'save_vis_results',
'load_det_results', 'load_det_results',
'preprocess_reid', 'preprocess_reid',
'get_crops', 'get_crops',
...@@ -29,7 +32,7 @@ __all__ = [ ...@@ -29,7 +32,7 @@ __all__ = [
] ]
class Timer(object): class MOTTimer(object):
""" """
This class used to compute and print the current FPS while evaling. This class used to compute and print the current FPS while evaling.
""" """
...@@ -106,6 +109,68 @@ class Detection(object): ...@@ -106,6 +109,68 @@ class Detection(object):
return ret return ret
def write_mot_results(filename, results, data_type='mot', num_classes=1):
# support single and multi classes
if data_type in ['mot', 'mcmot']:
save_format = '{frame},{id},{x1},{y1},{w},{h},{score},{cls_id},-1,-1\n'
elif data_type == 'kitti':
save_format = '{frame} {id} car 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
else:
raise ValueError(data_type)
f = open(filename, 'w')
for cls_id in range(num_classes):
for frame_id, tlwhs, tscores, track_ids in results[cls_id]:
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0: continue
if data_type == 'kitti':
frame_id -= 1
elif data_type == 'mot':
cls_id = -1
elif data_type == 'mcmot':
cls_id = cls_id
x1, y1, w, h = tlwh
line = save_format.format(
frame=frame_id,
id=track_id,
x1=x1,
y1=y1,
w=w,
h=h,
score=score,
cls_id=cls_id)
f.write(line)
print('MOT results save in {}'.format(filename))
def save_vis_results(data,
frame_id,
online_ids,
online_tlwhs,
online_scores,
average_time,
show_image,
save_dir,
num_classes=1):
if show_image or save_dir is not None:
assert 'ori_image' in data
img0 = data['ori_image'].numpy()[0]
online_im = plot_tracking_dict(
img0,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=1. / average_time)
if show_image:
cv2.imshow('online_im', online_im)
if save_dir is not None:
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
def load_det_results(det_file, num_frames): def load_det_results(det_file, num_frames):
assert os.path.exists(det_file) and os.path.isfile(det_file), \ assert os.path.exists(det_file) and os.path.isfile(det_file), \
'{} is not exist or not a file.'.format(det_file) '{} is not exist or not a file.'.format(det_file)
......
...@@ -16,28 +16,12 @@ import cv2 ...@@ -16,28 +16,12 @@ import cv2
import numpy as np import numpy as np
def tlwhs_to_tlbrs(tlwhs):
tlbrs = np.copy(tlwhs)
if len(tlbrs) == 0:
return tlbrs
tlbrs[:, 2] += tlwhs[:, 0]
tlbrs[:, 3] += tlwhs[:, 1]
return tlbrs
def get_color(idx): def get_color(idx):
idx = idx * 3 idx = idx * 3
color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255) color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
return color return color
def resize_image(image, max_size=800):
if max(image.shape[:2]) > max_size:
scale = float(max_size) / max(image.shape[:2])
image = cv2.resize(image, None, fx=scale, fy=scale)
return image
def plot_tracking(image, def plot_tracking(image,
tlwhs, tlwhs,
obj_ids, obj_ids,
...@@ -92,44 +76,67 @@ def plot_tracking(image, ...@@ -92,44 +76,67 @@ def plot_tracking(image,
return im return im
def plot_trajectory(image, tlwhs, track_ids): def plot_tracking_dict(image,
image = image.copy() num_classes,
for one_tlwhs, track_id in zip(tlwhs, track_ids): tlwhs_dict,
color = get_color(int(track_id)) obj_ids_dict,
for tlwh in one_tlwhs: scores_dict,
x1, y1, w, h = tuple(map(int, tlwh)) frame_id=0,
cv2.circle( fps=0.,
image, (int(x1 + 0.5 * w), int(y1 + h)), 2, color, thickness=2) ids2=None):
return image im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2]
def plot_detections(image, tlbrs, scores=None, color=(255, 0, 0), ids=None): top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255
im = np.copy(image)
text_scale = max(1, image.shape[1] / 800.) text_scale = max(1, image.shape[1] / 1600.)
thickness = 2 if text_scale > 1.3 else 1 text_thickness = 2
for i, det in enumerate(tlbrs): line_thickness = max(1, int(image.shape[1] / 500.))
x1, y1, x2, y2 = np.asarray(det[:4], dtype=np.int)
if len(det) >= 7: radius = max(5, int(im_w / 140.))
label = 'det' if det[5] > 0 else 'trk'
if ids is not None: for cls_id in range(num_classes):
text = '{}# {:.2f}: {:d}'.format(label, det[6], ids[i]) tlwhs = tlwhs_dict[cls_id]
cv2.putText( obj_ids = obj_ids_dict[cls_id]
im, scores = scores_dict[cls_id]
text, (x1, y1 + 30), cv2.putText(
cv2.FONT_HERSHEY_PLAIN, im,
text_scale, (0, 255, 255), 'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)),
thickness=thickness) (0, int(15 * text_scale)),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=2)
for i, tlwh in enumerate(tlwhs):
x1, y1, w, h = tlwh
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
obj_id = int(obj_ids[i])
if num_classes == 1:
id_text = '{}'.format(int(obj_id))
else: else:
text = '{}# {:.2f}'.format(label, det[6]) id_text = 'class{}_id{}'.format(cls_id, int(obj_id))
if scores is not None: _line_thickness = 1 if obj_id <= 0 else line_thickness
text = '{:.2f}'.format(scores[i]) color = get_color(abs(obj_id))
cv2.rectangle(
im,
intbox[0:2],
intbox[2:4],
color=color,
thickness=line_thickness)
cv2.putText( cv2.putText(
im, im,
text, (x1, y1 + 30), id_text, (intbox[0], intbox[1] + 10),
cv2.FONT_HERSHEY_PLAIN, cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255), text_scale, (0, 0, 255),
thickness=thickness) thickness=text_thickness)
cv2.rectangle(im, (x1, y1), (x2, y2), color, 2) if scores is not None:
text = '{:.2f}'.format(float(scores[i]))
cv2.putText(
im,
text, (intbox[0], intbox[1] - 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=text_thickness)
return im return im
...@@ -415,7 +415,6 @@ class CenterNetPostProcess(TTFBox): ...@@ -415,7 +415,6 @@ class CenterNetPostProcess(TTFBox):
regress_ltrb (bool): whether to regress left/top/right/bottom or regress_ltrb (bool): whether to regress left/top/right/bottom or
width/height for a box, true by default. width/height for a box, true by default.
for_mot (bool): whether return other features used in tracking model. for_mot (bool): whether return other features used in tracking model.
""" """
__shared__ = ['down_ratio', 'for_mot'] __shared__ = ['down_ratio', 'for_mot']
...@@ -433,9 +432,9 @@ class CenterNetPostProcess(TTFBox): ...@@ -433,9 +432,9 @@ class CenterNetPostProcess(TTFBox):
def __call__(self, hm, wh, reg, im_shape, scale_factor): def __call__(self, hm, wh, reg, im_shape, scale_factor):
heat = self._simple_nms(hm) heat = self._simple_nms(hm)
scores, inds, clses, ys, xs = self._topk(heat) scores, inds, topk_clses, ys, xs = self._topk(heat)
scores = paddle.tensor.unsqueeze(scores, [1]) scores = paddle.tensor.unsqueeze(scores, [1])
clses = paddle.tensor.unsqueeze(clses, [1]) clses = paddle.tensor.unsqueeze(topk_clses, [1])
reg_t = paddle.transpose(reg, [0, 2, 3, 1]) reg_t = paddle.transpose(reg, [0, 2, 3, 1])
# Like TTFBox, batch size is 1. # Like TTFBox, batch size is 1.
...@@ -486,10 +485,10 @@ class CenterNetPostProcess(TTFBox): ...@@ -486,10 +485,10 @@ class CenterNetPostProcess(TTFBox):
bboxes = paddle.divide(bboxes, scale_expand) bboxes = paddle.divide(bboxes, scale_expand)
if self.for_mot: if self.for_mot:
results = paddle.concat([bboxes, scores, clses], axis=1) results = paddle.concat([bboxes, scores, clses], axis=1)
return results, inds return results, inds, topk_clses
else: else:
results = paddle.concat([clses, scores, bboxes], axis=1) results = paddle.concat([clses, scores, bboxes], axis=1)
return results, paddle.shape(results)[0:1] return results, paddle.shape(results)[0:1], topk_clses
@register @register
......
...@@ -26,21 +26,27 @@ __all__ = ['FairMOTEmbeddingHead'] ...@@ -26,21 +26,27 @@ __all__ = ['FairMOTEmbeddingHead']
@register @register
class FairMOTEmbeddingHead(nn.Layer): class FairMOTEmbeddingHead(nn.Layer):
__shared__ = ['num_classes']
""" """
Args: Args:
in_channels (int): the channel number of input to FairMOTEmbeddingHead. in_channels (int): the channel number of input to FairMOTEmbeddingHead.
ch_head (int): the channel of features before fed into embedding, 256 by default. ch_head (int): the channel of features before fed into embedding, 256 by default.
ch_emb (int): the channel of the embedding feature, 128 by default. ch_emb (int): the channel of the embedding feature, 128 by default.
num_identifiers (int): the number of identifiers, 14455 by default. num_identities_dict (dict): the number of identities of each category,
support single class and multi-calss, {0: 14455} as default.
""" """
def __init__(self, def __init__(self,
in_channels, in_channels,
ch_head=256, ch_head=256,
ch_emb=128, ch_emb=128,
num_identifiers=14455): num_classes=1,
num_identities_dict={0: 14455}):
super(FairMOTEmbeddingHead, self).__init__() super(FairMOTEmbeddingHead, self).__init__()
assert num_classes >= 1
self.num_classes = num_classes
self.ch_emb = ch_emb
self.num_identities_dict = num_identities_dict
self.reid = nn.Sequential( self.reid = nn.Sequential(
ConvLayer( ConvLayer(
in_channels, ch_head, kernel_size=3, padding=1, bias=True), in_channels, ch_head, kernel_size=3, padding=1, bias=True),
...@@ -50,15 +56,27 @@ class FairMOTEmbeddingHead(nn.Layer): ...@@ -50,15 +56,27 @@ class FairMOTEmbeddingHead(nn.Layer):
param_attr = paddle.ParamAttr(initializer=KaimingUniform()) param_attr = paddle.ParamAttr(initializer=KaimingUniform())
bound = 1 / math.sqrt(ch_emb) bound = 1 / math.sqrt(ch_emb)
bias_attr = paddle.ParamAttr(initializer=Uniform(-bound, bound)) bias_attr = paddle.ParamAttr(initializer=Uniform(-bound, bound))
self.classifier = nn.Linear(
ch_emb,
num_identifiers,
weight_attr=param_attr,
bias_attr=bias_attr)
self.reid_loss = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum') self.reid_loss = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum')
# When num_identifiers is 1, emb_scale is set as 1
self.emb_scale = math.sqrt(2) * math.log( if num_classes == 1:
num_identifiers - 1) if num_identifiers > 1 else 1 nID = self.num_identities_dict[0] # single class
self.classifier = nn.Linear(
ch_emb,
nID,
weight_attr=param_attr,
bias_attr=bias_attr)
# When num_identities(nID) is 1, emb_scale is set as 1
self.emb_scale = math.sqrt(2) * math.log(
nID - 1) if nID > 1 else 1
else:
self.classifiers = dict()
self.emb_scale_dict = dict()
for cls_id, nID in self.num_identities_dict.items():
self.classifiers[str(cls_id)] = nn.Linear(
ch_emb, nID, weight_attr=param_attr, bias_attr=bias_attr)
# When num_identities(nID) is 1, emb_scale is set as 1
self.emb_scale_dict[str(cls_id)] = math.sqrt(2) * math.log(
nID - 1) if nID > 1 else 1
@classmethod @classmethod
def from_config(cls, cfg, input_shape): def from_config(cls, cfg, input_shape):
...@@ -66,14 +84,56 @@ class FairMOTEmbeddingHead(nn.Layer): ...@@ -66,14 +84,56 @@ class FairMOTEmbeddingHead(nn.Layer):
input_shape = input_shape[0] input_shape = input_shape[0]
return {'in_channels': input_shape.channels} return {'in_channels': input_shape.channels}
def forward(self, feat, inputs): def process_by_class(self, det_outs, embedding, bbox_inds, topk_clses):
pred_dets, pred_embs = [], []
for cls_id in range(self.num_classes):
inds_masks = topk_clses == cls_id
inds_masks = paddle.cast(inds_masks, 'float32')
pos_num = inds_masks.sum().numpy()
if pos_num == 0:
continue
cls_inds_mask = inds_masks > 0
bbox_mask = paddle.nonzero(cls_inds_mask)
cls_det_outs = paddle.gather_nd(det_outs, bbox_mask)
pred_dets.append(cls_det_outs)
cls_inds = paddle.masked_select(bbox_inds, cls_inds_mask)
cls_inds = cls_inds.unsqueeze(-1)
cls_embedding = paddle.gather_nd(embedding, cls_inds)
pred_embs.append(cls_embedding)
return paddle.concat(pred_dets), paddle.concat(pred_embs)
def forward(self,
feat,
inputs,
det_outs=None,
bbox_inds=None,
topk_clses=None):
reid_feat = self.reid(feat) reid_feat = self.reid(feat)
if self.training: if self.training:
loss = self.get_loss(reid_feat, inputs) if self.num_classes == 1:
loss = self.get_loss(reid_feat, inputs)
else:
loss = self.get_mc_loss(reid_feat, inputs)
return loss return loss
else: else:
assert det_outs is not None and bbox_inds is not None
reid_feat = F.normalize(reid_feat) reid_feat = F.normalize(reid_feat)
return reid_feat embedding = paddle.transpose(reid_feat, [0, 2, 3, 1])
embedding = paddle.reshape(embedding, [-1, self.ch_emb])
# embedding shape: [bs * h * w, ch_emb]
if self.num_classes == 1:
pred_dets = det_outs
pred_embs = paddle.gather(embedding, bbox_inds)
else:
pred_dets, pred_embs = self.process_by_class(
det_outs, embedding, bbox_inds, topk_clses)
return pred_dets, pred_embs
def get_loss(self, feat, inputs): def get_loss(self, feat, inputs):
index = inputs['index'] index = inputs['index']
...@@ -113,3 +173,56 @@ class FairMOTEmbeddingHead(nn.Layer): ...@@ -113,3 +173,56 @@ class FairMOTEmbeddingHead(nn.Layer):
loss = loss / count loss = loss / count
return loss return loss
def get_mc_loss(self, feat, inputs):
# feat.shape = [bs, ch_emb, h, w]
assert 'cls_id_map' in inputs and 'cls_tr_ids' in inputs
index = inputs['index']
mask = inputs['index_mask']
cls_id_map = inputs['cls_id_map'] # [bs, h, w]
cls_tr_ids = inputs['cls_tr_ids'] # [bs, num_classes, h, w]
feat = paddle.transpose(feat, perm=[0, 2, 3, 1])
feat_n, feat_h, feat_w, feat_c = feat.shape
feat = paddle.reshape(feat, shape=[feat_n, -1, feat_c])
index = paddle.unsqueeze(index, 2)
batch_inds = list()
for i in range(feat_n):
batch_ind = paddle.full(
shape=[1, index.shape[1], 1], fill_value=i, dtype='int64')
batch_inds.append(batch_ind)
batch_inds = paddle.concat(batch_inds, axis=0)
index = paddle.concat(x=[batch_inds, index], axis=2)
feat = paddle.gather_nd(feat, index=index)
mask = paddle.unsqueeze(mask, axis=2)
mask = paddle.expand_as(mask, feat)
mask.stop_gradient = True
feat = paddle.masked_select(feat, mask > 0)
feat = paddle.reshape(feat, shape=[-1, feat_c])
reid_losses = 0
for cls_id, id_num in self.num_identities_dict.items():
# target
cur_cls_tr_ids = paddle.reshape(
cls_tr_ids[:, cls_id, :, :], shape=[feat_n, -1]) # [bs, h*w]
cls_id_target = paddle.gather_nd(cur_cls_tr_ids, index=index)
mask = inputs['index_mask']
cls_id_target = paddle.masked_select(cls_id_target, mask > 0)
cls_id_target.stop_gradient = True
# feat
cls_id_feat = self.emb_scale_dict[str(cls_id)] * F.normalize(feat)
cls_id_pred = self.classifiers[str(cls_id)](cls_id_feat)
loss = self.reid_loss(cls_id_pred, cls_id_target)
valid = (cls_id_target != self.reid_loss.ignore_index)
valid.stop_gradient = True
count = paddle.sum((paddle.cast(valid, dtype=np.int32)))
count.stop_gradient = True
if count > 0:
loss = loss / count
reid_losses += loss
return reid_losses
...@@ -49,7 +49,7 @@ class JDEEmbeddingHead(nn.Layer): ...@@ -49,7 +49,7 @@ class JDEEmbeddingHead(nn.Layer):
JDEEmbeddingHead JDEEmbeddingHead
Args: Args:
num_classes(int): Number of classes. Only support one class tracking. num_classes(int): Number of classes. Only support one class tracking.
num_identifiers(int): Number of identifiers. num_identities(int): Number of identities.
anchor_levels(int): Number of anchor levels, same as FPN levels. anchor_levels(int): Number of anchor levels, same as FPN levels.
anchor_scales(int): Number of anchor scales on each FPN level. anchor_scales(int): Number of anchor scales on each FPN level.
embedding_dim(int): Embedding dimension. Default: 512. embedding_dim(int): Embedding dimension. Default: 512.
...@@ -60,7 +60,7 @@ class JDEEmbeddingHead(nn.Layer): ...@@ -60,7 +60,7 @@ class JDEEmbeddingHead(nn.Layer):
def __init__( def __init__(
self, self,
num_classes=1, num_classes=1,
num_identifiers=14455, # defined by dataset.total_identities when training num_identities=14455, # dataset.num_identities_dict[0]
anchor_levels=3, anchor_levels=3,
anchor_scales=4, anchor_scales=4,
embedding_dim=512, embedding_dim=512,
...@@ -68,7 +68,7 @@ class JDEEmbeddingHead(nn.Layer): ...@@ -68,7 +68,7 @@ class JDEEmbeddingHead(nn.Layer):
jde_loss='JDELoss'): jde_loss='JDELoss'):
super(JDEEmbeddingHead, self).__init__() super(JDEEmbeddingHead, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.num_identifiers = num_identifiers self.num_identities = num_identities
self.anchor_levels = anchor_levels self.anchor_levels = anchor_levels
self.anchor_scales = anchor_scales self.anchor_scales = anchor_scales
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
...@@ -76,7 +76,7 @@ class JDEEmbeddingHead(nn.Layer): ...@@ -76,7 +76,7 @@ class JDEEmbeddingHead(nn.Layer):
self.jde_loss = jde_loss self.jde_loss = jde_loss
self.emb_scale = math.sqrt(2) * math.log( self.emb_scale = math.sqrt(2) * math.log(
self.num_identifiers - 1) if self.num_identifiers > 1 else 1 self.num_identities - 1) if self.num_identities > 1 else 1
self.identify_outputs = [] self.identify_outputs = []
self.loss_params_cls = [] self.loss_params_cls = []
...@@ -106,7 +106,7 @@ class JDEEmbeddingHead(nn.Layer): ...@@ -106,7 +106,7 @@ class JDEEmbeddingHead(nn.Layer):
'classifier', 'classifier',
nn.Linear( nn.Linear(
self.embedding_dim, self.embedding_dim,
self.num_identifiers, self.num_identities,
weight_attr=ParamAttr( weight_attr=ParamAttr(
learning_rate=1., initializer=Normal( learning_rate=1., initializer=Normal(
mean=0.0, std=0.01)), mean=0.0, std=0.01)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册