Unverified commit bb697394, authored by Feng Ni, committed by GitHub

[MOT]fix kitti cfgs and update kitti modelzoo (#4358)

Parent 0adb3cbc
@@ -15,7 +15,7 @@
| Backbone | Input Size | MOTA | FPS | Download | Config |
| :--------------| :------- | :-----: | :-----: | :------: | :----: |
| DLA-34 | 1088x608 | 82.7 | - | [Download](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitti_vehicle.pdparams) | [Config](./fairmot_dla34_30e_1088x608_kitti_vehicle.yml) |

**Note:**
FairMOT is trained on 2 GPUs with a batch size of 6 per GPU for 30 epochs.
@@ -25,36 +25,36 @@
### 1. Training
Launch training on 2 GPUs with the following command:
```bash
python -m paddle.distributed.launch --log_dir=./fairmot_dla34_30e_1088x608_kitti_vehicle/ --gpus 0,1 tools/train.py -c configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml
```
### 2. Evaluation
Launch evaluation on a single GPU with the following commands:
```bash
# Use the weights released by PaddleDetection
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitti_vehicle.pdparams

# Use a checkpoint saved during training
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml -o weights=output/fairmot_dla34_30e_1088x608_kitti_vehicle/model_final.pdparams
```
### 3. Inference
Run inference on a video on a single GPU with the following command, saving the result as a video:
```bash
# Infer a video
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitti_vehicle.pdparams --video_file={your video name}.mp4 --save_videos
```
**Note:**
Make sure [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first. On Linux (Ubuntu) it can be installed directly with `apt-get update && apt-get install -y ffmpeg`.
### 4. Export the inference model
```bash
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitti_vehicle.pdparams
```
### 5. Inference with the exported model in Python
```bash
python deploy/python/mot_jde_infer.py --model_dir=output_inference/fairmot_dla34_30e_1088x608_kitti_vehicle --video_file={your video name}.mp4 --device=GPU --save_mot_txts
```
**Note:**
The tracking model predicts on videos; single-image prediction is not supported. By default, a video with the visualized tracking results is saved. Add `--save_mot_txts` to also save the tracking results as txt files, or `--save_images` to save the visualized tracking images.
......
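For downstream processing of the saved txt files, a minimal parsing sketch is shown below. It assumes the MOTChallenge-style comma-separated layout `frame,id,x,y,w,h,score,...`; the exact columns written for KITTI data may differ, so adjust the field indices accordingly.

```python
from collections import defaultdict

def load_mot_txt(path):
    """Group tracking results by frame id.

    Assumes MOTChallenge-style lines: frame,id,x,y,w,h,score,...
    (an assumption; KITTI-format output uses different columns).
    """
    results = defaultdict(list)
    with open(path) as f:
        for line in f:
            fields = line.strip().split(',')
            if len(fields) < 7:
                continue
            frame, tid = int(float(fields[0])), int(float(fields[1]))
            x, y, w, h, score = map(float, fields[2:7])
            results[frame].append((tid, (x, y, w, h), score))
    return results

# e.g.: per_frame = load_mot_txt('path/to/saved_results.txt')
```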
@@ -3,13 +3,13 @@ _BASE_: [
]

metric: KITTI
weights: output/fairmot_dla34_30e_1088x608_kitti_vehicle/model_final
# for MOT training
TrainDataset:
  !MOTDataSet
    dataset_dir: dataset/mot
    image_lists: ['kitti_vehicle.train']
    data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']

# for MOT evaluation
@@ -17,7 +17,7 @@ TrainDataset:
EvalMOTDataset:
  !MOTImageFolder
    dataset_dir: dataset/mot
    data_root: kitti_vehicle/images/train
    keep_ori_im: False # set True to save visualization images or video, or when used in DeepSORT

# for MOT video inference
@@ -25,3 +25,17 @@ TestMOTDataset:
  !MOTImageFolder
    dataset_dir: dataset/mot
    keep_ori_im: True # set True to save visualization images or video

# model config
FairMOT:
  detector: CenterNet
  reid: FairMOTEmbeddingHead
  loss: FairMOTLoss
  tracker: JDETracker

JDETracker:
  min_box_area: 200
  vertical_ratio: 0 # no need to filter bboxes according to w/h
  conf_thres: 0.4
  tracked_thresh: 0.4
  metric_type: cosine
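The two new parameters above control the post-tracking box filter that this commit also adds to `deploy/python/mot_jde_infer.py` and `ppdet/engine/tracker.py`, as the diffs below show. A minimal standalone sketch of that rule (the function name is hypothetical, not a PaddleDetection API):

```python
def keep_track(tlwh, min_box_area=200, vertical_ratio=0):
    """Return False for boxes the tracker output should drop.

    tlwh: (x, y, w, h) box. Mirrors the filter added in this commit.
    """
    x, y, w, h = tlwh
    if w * h <= min_box_area:  # tiny boxes are usually low quality
        return False
    if vertical_ratio > 0 and w / h > vertical_ratio:
        return False  # too wide for the target class (e.g. pedestrians)
    return True

# vertical_ratio=0, as in the vehicle config above, disables the w/h check,
# so wide vehicle boxes are kept:
assert keep_track((0, 0, 50, 10), vertical_ratio=0)
assert not keep_track((0, 0, 10, 10))  # 100 px^2 <= 200, dropped
```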
@@ -81,10 +81,14 @@ class JDE_Detector(Detector):
        assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
        assert pred_config.tracker, "Tracking model should have tracker"
        tp = pred_config.tracker
        min_box_area = tp['min_box_area'] if 'min_box_area' in tp else 200
        vertical_ratio = tp['vertical_ratio'] if 'vertical_ratio' in tp else 1.6
        conf_thres = tp['conf_thres'] if 'conf_thres' in tp else 0.
        tracked_thresh = tp['tracked_thresh'] if 'tracked_thresh' in tp else 0.7
        metric_type = tp['metric_type'] if 'metric_type' in tp else 'euclidean'
        self.tracker = JDETracker(
            min_box_area=min_box_area,
            vertical_ratio=vertical_ratio,
            conf_thres=conf_thres,
            tracked_thresh=tracked_thresh,
            metric_type=metric_type)
@@ -100,8 +104,10 @@ class JDE_Detector(Detector):
            tid = t.track_id
            tscore = t.score
            if tscore < threshold: continue
            if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
            if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[3] > self.tracker.vertical_ratio:
                continue
            online_tlwhs.append(tlwh)
            online_ids.append(tid)
            online_scores.append(tscore)
......
@@ -31,6 +31,9 @@ class JDETracker(object):
        det_thresh (float): threshold of detection score
        track_buffer (int): buffer for tracker
        min_box_area (int): min box area to filter out low quality boxes
        vertical_ratio (float): w/h ratio of the bbox used to filter out bad
            results; 1.6 by default for pedestrian tracking. Set it to a
            non-positive value (e.g. -1) to disable the filter.
        tracked_thresh (float): linear assignment threshold of tracked
            stracks and detections
        r_tracked_thresh (float): linear assignment threshold of
@@ -47,6 +50,7 @@ class JDETracker(object):
                 det_thresh=0.3,
                 track_buffer=30,
                 min_box_area=200,
                 vertical_ratio=1.6,
                 tracked_thresh=0.7,
                 r_tracked_thresh=0.5,
                 unconfirmed_thresh=0.7,
@@ -56,6 +60,8 @@ class JDETracker(object):
        self.det_thresh = det_thresh
        self.track_buffer = track_buffer
        self.min_box_area = min_box_area
        self.vertical_ratio = vertical_ratio
        self.tracked_thresh = tracked_thresh
        self.r_tracked_thresh = r_tracked_thresh
        self.unconfirmed_thresh = unconfirmed_thresh
......
@@ -144,8 +144,10 @@ class Tracker(object):
                tid = t.track_id
                tscore = t.score
                if tscore < draw_threshold: continue
                if tlwh[2] * tlwh[3] <= tracker.min_box_area: continue
                if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[3] > tracker.vertical_ratio:
                    continue
                online_tlwhs.append(tlwh)
                online_ids.append(tid)
                online_scores.append(tscore)
......
@@ -375,7 +375,7 @@ class KITTIEvaluation(object):
        # get number of sequences and
        # get number of frames per sequence from test mapping
        # (created while extracting the benchmark)
        self.gt_path = os.path.join(gt_path, "../labels")
        self.n_frames = n_frames
        self.sequence_name = seqs
        self.n_sequences = n_sequences
@@ -1177,7 +1177,7 @@ class KITTIMOTMetric(Metric):
        assert data_type == 'kitti', "data_type should be 'kitti'"
        self.result_root = result_root
        self.gt_path = data_root
        gt_path = '{}/../labels/{}.txt'.format(data_root, seq)
        gt = open(gt_path, "r")
        max_frame = 0
        for line in gt:
......
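The `label_02` to `../labels` change above is plain relative-path arithmetic. A small sketch of how the ground-truth path resolves, assuming `data_root` follows the `EvalMOTDataset` setting from the config earlier (the sequence name `0000` is hypothetical):

```python
import os

# Assumption: data_root combines dataset_dir and data_root from the config above.
data_root = 'dataset/mot/kitti_vehicle/images/train'
seq = '0000'  # hypothetical sequence name

gt_path = '{}/../labels/{}.txt'.format(data_root, seq)
print(os.path.normpath(gt_path))
# -> dataset/mot/kitti_vehicle/images/labels/0000.txt
```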
@@ -39,6 +39,9 @@ class JDETracker(object):
        det_thresh (float): threshold of detection score
        track_buffer (int): buffer for tracker
        min_box_area (int): min box area to filter out low quality boxes
        vertical_ratio (float): w/h ratio of the bbox used to filter out bad
            results; 1.6 by default for pedestrian tracking. Set it to a
            non-positive value (e.g. -1) to disable the filter.
        tracked_thresh (float): linear assignment threshold of tracked
            stracks and detections
        r_tracked_thresh (float): linear assignment threshold of
@@ -55,6 +58,7 @@ class JDETracker(object):
                 det_thresh=0.3,
                 track_buffer=30,
                 min_box_area=200,
                 vertical_ratio=1.6,
                 tracked_thresh=0.7,
                 r_tracked_thresh=0.5,
                 unconfirmed_thresh=0.7,
@@ -64,6 +68,8 @@ class JDETracker(object):
        self.det_thresh = det_thresh
        self.track_buffer = track_buffer
        self.min_box_area = min_box_area
        self.vertical_ratio = vertical_ratio
        self.tracked_thresh = tracked_thresh
        self.r_tracked_thresh = r_tracked_thresh
        self.unconfirmed_thresh = unconfirmed_thresh
......