From bb697394ed42f6da77fded52cea95dba3300ae90 Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Fri, 22 Oct 2021 14:16:54 +0800 Subject: [PATCH] [MOT]fix kitti cfgs and update kitti modelzoo (#4358) --- configs/mot/{kitticars => vehicle}/README.md | 0 .../mot/{kitticars => vehicle}/README_cn.md | 14 ++++++------- ...rmot_dla34_30e_1088x608_kitti_vehicle.yml} | 20 ++++++++++++++++--- deploy/python/mot_jde_infer.py | 16 ++++++++++----- deploy/python/tracker/jde_tracker.py | 6 ++++++ ppdet/engine/tracker.py | 12 ++++++----- ppdet/metrics/mot_metrics.py | 4 ++-- ppdet/modeling/mot/tracker/jde_tracker.py | 6 ++++++ 8 files changed, 56 insertions(+), 22 deletions(-) rename configs/mot/{kitticars => vehicle}/README.md (100%) rename configs/mot/{kitticars => vehicle}/README_cn.md (70%) rename configs/mot/{kitticars/fairmot_dla34_30e_1088x608_kitticars.yml => vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml} (59%) diff --git a/configs/mot/kitticars/README.md b/configs/mot/vehicle/README.md similarity index 100% rename from configs/mot/kitticars/README.md rename to configs/mot/vehicle/README.md diff --git a/configs/mot/kitticars/README_cn.md b/configs/mot/vehicle/README_cn.md similarity index 70% rename from configs/mot/kitticars/README_cn.md rename to configs/mot/vehicle/README_cn.md index 1b60678cb..4fc017ef0 100644 --- a/configs/mot/kitticars/README_cn.md +++ b/configs/mot/vehicle/README_cn.md @@ -15,7 +15,7 @@ | 骨干网络 | 输入尺寸 | MOTA | FPS | 下载链接 | 配置文件 | | :--------------| :------- | :-----: | :-----: | :------: | :----: | -| DLA-34 | 1088x608 | 53.9 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitticars.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.2/configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml) | +| DLA-34 | 1088x608 | 82.7 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitti_vehicle.pdparams) | 
[配置文件](./fairmot_dla34_30e_1088x608_kitti_vehicle.yml) | **注意:** FairMOT使用2个GPU进行训练,每个GPU上batch size为6,训练30个epoch。 @@ -25,36 +25,36 @@ ### 1. 训练 使用2个GPU通过如下命令一键式启动训练 ```bash -python -m paddle.distributed.launch --log_dir=./fairmot_dla34_30e_1088x608_kitticars/ --gpus 0,1 tools/train.py -c configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml +python -m paddle.distributed.launch --log_dir=./fairmot_dla34_30e_1088x608_kitti_vehicle/ --gpus 0,1 tools/train.py -c configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml ``` ### 2. 评估 使用单张GPU通过如下命令一键式启动评估 ```bash # 使用PaddleDetection发布的权重 -CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitticars.pdparams +CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitti_vehicle.pdparams # 使用训练保存的checkpoint -CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml -o weights=output/fairmot_dla34_30e_1088x608_kitticars/model_final.pdparams +CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml -o weights=output/fairmot_dla34_30e_1088x608_kitti_vehicle/model_final.pdparams ``` ### 3. 
预测 使用单个GPU通过如下命令预测一个视频,并保存为视频 ```bash # 预测一个视频 -CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitticars.pdparams --video_file={your video name}.mp4 --save_videos +CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitti_vehicle.pdparams --video_file={your video name}.mp4 --save_videos ``` **注意:** 请先确保已经安装了[ffmpeg](https://ffmpeg.org/ffmpeg.html), Linux(Ubuntu)平台可以直接用以下命令安装:`apt-get update && apt-get install -y ffmpeg`。 ### 4. 导出预测模型 ```bash -CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitticars.pdparams +CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitti_vehicle.pdparams ``` ### 5. 
用导出的模型基于Python去预测 ```bash -python deploy/python/mot_jde_infer.py --model_dir=output_inference/fairmot_dla34_30e_1088x608_kitticars --video_file={your video name}.mp4 --device=GPU --save_mot_txts +python deploy/python/mot_jde_infer.py --model_dir=output_inference/fairmot_dla34_30e_1088x608_kitti_vehicle --video_file={your video name}.mp4 --device=GPU --save_mot_txts ``` **注意:** 跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`表示保存跟踪结果的txt文件,或`--save_images`表示保存跟踪结果可视化图片。 diff --git a/configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml b/configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml similarity index 59% rename from configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml rename to configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml index a103cba6c..6f3c9e97d 100755 --- a/configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml +++ b/configs/mot/vehicle/fairmot_dla34_30e_1088x608_kitti_vehicle.yml @@ -3,13 +3,13 @@ _BASE_: [ ] metric: KITTI -weights: output/fairmot_dla34_30e_1088x608_kitticars/model_final +weights: output/fairmot_dla34_30e_1088x608_kitti_vehicle/model_final # for MOT training TrainDataset: !MOTDataSet dataset_dir: dataset/mot - image_lists: ['kitticars.train'] + image_lists: ['kitti_vehicle.train'] data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide'] # for MOT evaluation @@ -17,7 +17,7 @@ TrainDataset: EvalMOTDataset: !MOTImageFolder dataset_dir: dataset/mot - data_root: kitticars/images/test + data_root: kitti_vehicle/images/train keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT # for MOT video inference @@ -25,3 +25,17 @@ TestMOTDataset: !MOTImageFolder dataset_dir: dataset/mot keep_ori_im: True # set True if save visualization images or video + +# model config +FairMOT: + detector: CenterNet + reid: FairMOTEmbeddingHead + loss: FairMOTLoss + tracker: JDETracker + +JDETracker: + min_box_area: 200 + vertical_ratio: 0 # no need to filter bboxes 
according to w/h + conf_thres: 0.4 + tracked_thresh: 0.4 + metric_type: cosine diff --git a/deploy/python/mot_jde_infer.py b/deploy/python/mot_jde_infer.py index de73ca077..95c1f52ed 100644 --- a/deploy/python/mot_jde_infer.py +++ b/deploy/python/mot_jde_infer.py @@ -81,10 +81,14 @@ class JDE_Detector(Detector): assert batch_size == 1, "The JDE Detector only supports batch size=1 now" assert pred_config.tracker, "Tracking model should have tracker" tp = pred_config.tracker + min_box_area = tp['min_box_area'] if 'min_box_area' in tp else 200 + vertical_ratio = tp['vertical_ratio'] if 'vertical_ratio' in tp else 1.6 conf_thres = tp['conf_thres'] if 'conf_thres' in tp else 0. tracked_thresh = tp['tracked_thresh'] if 'tracked_thresh' in tp else 0.7 metric_type = tp['metric_type'] if 'metric_type' in tp else 'euclidean' self.tracker = JDETracker( + min_box_area=min_box_area, + vertical_ratio=vertical_ratio, conf_thres=conf_thres, tracked_thresh=tracked_thresh, metric_type=metric_type) @@ -100,11 +104,13 @@ class JDE_Detector(Detector): tid = t.track_id tscore = t.score if tscore < threshold: continue - vertical = tlwh[2] / tlwh[3] > 1.6 - if tlwh[2] * tlwh[3] > self.tracker.min_box_area and not vertical: - online_tlwhs.append(tlwh) - online_ids.append(tid) - online_scores.append(tscore) + if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue + if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[ + 3] > self.tracker.vertical_ratio: + continue + online_tlwhs.append(tlwh) + online_ids.append(tid) + online_scores.append(tscore) return online_tlwhs, online_scores, online_ids def predict(self, image_list, threshold=0.5, warmup=0, repeats=1): diff --git a/deploy/python/tracker/jde_tracker.py b/deploy/python/tracker/jde_tracker.py index b44e8700d..57e87da84 100644 --- a/deploy/python/tracker/jde_tracker.py +++ b/deploy/python/tracker/jde_tracker.py @@ -31,6 +31,9 @@ class JDETracker(object): det_thresh (float): threshold of detection score track_buffer (int): buffer for 
tracker min_box_area (int): min box area to filter out low quality boxes + vertical_ratio (float): w/h, the vertical ratio of the bbox to filter + bad results, set 1.6 default for pedestrian tracking. If set -1 + means no need to filter bboxes. tracked_thresh (float): linear assignment threshold of tracked stracks and detections r_tracked_thresh (float): linear assignment threshold of @@ -47,6 +50,7 @@ class JDETracker(object): det_thresh=0.3, track_buffer=30, min_box_area=200, + vertical_ratio=1.6, tracked_thresh=0.7, r_tracked_thresh=0.5, unconfirmed_thresh=0.7, @@ -56,6 +60,8 @@ class JDETracker(object): self.det_thresh = det_thresh self.track_buffer = track_buffer self.min_box_area = min_box_area + self.vertical_ratio = vertical_ratio + self.tracked_thresh = tracked_thresh self.r_tracked_thresh = r_tracked_thresh self.unconfirmed_thresh = unconfirmed_thresh diff --git a/ppdet/engine/tracker.py b/ppdet/engine/tracker.py index 0d4f98b68..f5058e129 100644 --- a/ppdet/engine/tracker.py +++ b/ppdet/engine/tracker.py @@ -144,11 +144,13 @@ class Tracker(object): tid = t.track_id tscore = t.score if tscore < draw_threshold: continue - vertical = tlwh[2] / tlwh[3] > 1.6 - if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical: - online_tlwhs.append(tlwh) - online_ids.append(tid) - online_scores.append(tscore) + if tlwh[2] * tlwh[3] <= tracker.min_box_area: continue + if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[ + 3] > tracker.vertical_ratio: + continue + online_tlwhs.append(tlwh) + online_ids.append(tid) + online_scores.append(tscore) timer.toc() # save results diff --git a/ppdet/metrics/mot_metrics.py b/ppdet/metrics/mot_metrics.py index e70c0bd31..88ce96642 100644 --- a/ppdet/metrics/mot_metrics.py +++ b/ppdet/metrics/mot_metrics.py @@ -375,7 +375,7 @@ class KITTIEvaluation(object): # get number of sequences and # get number of frames per sequence from test mapping # (created while extracting the benchmark) - self.gt_path = os.path.join(gt_path, "label_02") 
+ self.gt_path = os.path.join(gt_path, "../labels") self.n_frames = n_frames self.sequence_name = seqs self.n_sequences = n_sequences @@ -1177,7 +1177,7 @@ class KITTIMOTMetric(Metric): assert data_type == 'kitti', "data_type should 'kitti'" self.result_root = result_root self.gt_path = data_root - gt_path = '{}/label_02/{}.txt'.format(data_root, seq) + gt_path = '{}/../labels/{}.txt'.format(data_root, seq) gt = open(gt_path, "r") max_frame = 0 for line in gt: diff --git a/ppdet/modeling/mot/tracker/jde_tracker.py b/ppdet/modeling/mot/tracker/jde_tracker.py index 9c9007e91..aa232da02 100644 --- a/ppdet/modeling/mot/tracker/jde_tracker.py +++ b/ppdet/modeling/mot/tracker/jde_tracker.py @@ -39,6 +39,9 @@ class JDETracker(object): det_thresh (float): threshold of detection score track_buffer (int): buffer for tracker min_box_area (int): min box area to filter out low quality boxes + vertical_ratio (float): w/h, the vertical ratio of the bbox to filter + bad results, set 1.6 default for pedestrian tracking. If set -1 + means no need to filter bboxes. tracked_thresh (float): linear assignment threshold of tracked stracks and detections r_tracked_thresh (float): linear assignment threshold of @@ -55,6 +58,7 @@ class JDETracker(object): det_thresh=0.3, track_buffer=30, min_box_area=200, + vertical_ratio=1.6, tracked_thresh=0.7, r_tracked_thresh=0.5, unconfirmed_thresh=0.7, @@ -64,6 +68,8 @@ class JDETracker(object): self.det_thresh = det_thresh self.track_buffer = track_buffer self.min_box_area = min_box_area + self.vertical_ratio = vertical_ratio + self.tracked_thresh = tracked_thresh self.r_tracked_thresh = r_tracked_thresh self.unconfirmed_thresh = unconfirmed_thresh -- GitLab