diff --git a/configs/mot/deepsort/README.md b/configs/mot/deepsort/README.md
index b479849387094fb24052d3ac08ac09f4013ec0fa..283e21ce1aeb5f8fe03da448e34094485d137df1 100644
--- a/configs/mot/deepsort/README.md
+++ b/configs/mot/deepsort/README.md
@@ -55,7 +55,7 @@ If you use a stronger detection model, you can get better results. Each txt is t
 - `width,height` is the pixel width and height
 - `conf` is the object score with default value `1` (the results had been filtered out according to the detection score threshold)
 
-- 2.Load the detection model and the ReID model at the same time. Here, the JDE version of YOLOv3 is selected. For more detail of configuration, see `configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml`.
+- 2. Load the detection model and the ReID model at the same time. Here, the JDE version of YOLOv3 is selected. For more details of the configuration, see `configs/mot/deepsort/_base_/deepsort_jde_yolov3_darknet53_pcb_pyramid_r101.yml`. To load another general detection model, refer to `configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml`.
 
 ## Getting Started
 
@@ -65,40 +65,60 @@ If you use a stronger detection model, you can get better results. Each txt is t
 # Load the result file and ReID model to get the tracking result
 CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml --det_results_dir {your detection results}
 
-# Load the detection model and ReID model to get the tracking results
-CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml
+# Load the JDE YOLOv3 detector and ReID model to get the tracking results
+CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_jde_yolov3_pcb_pyramid_r101.yml
+
+# or load the general YOLOv3 detector and ReID model to get the tracking results
+CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml --scaled=True
 ```
+**Notes:**
+The JDE YOLOv3 pedestrian detector is trained on the same MOT dataset as JDE and FairMOT. The biggest difference from the general YOLOv3 model is that it uses JDEBBoxPostProcess post-processing, so its output coordinates are not scaled back to the original image.
+The general YOLOv3 pedestrian detector is not trained on the MOT dataset, so its accuracy is lower, but its output coordinates are scaled back to the original image.
+`--scaled` indicates whether the coordinates output by the detector have already been scaled back to the original image; set it to False for the JDE YOLOv3 detector and to True for a general detector.
 
 ### 2. Inference
 Inference a video on a single GPU with the following command:
 ```bash
-# inference on video and save a video
-CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml --video_file={your video name}.mp4 --save_videos
+# load the JDE YOLOv3 pedestrian detector and ReID model to get tracking results
+CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsort_jde_yolov3_pcb_pyramid_r101.yml --video_file={your video name}.mp4 --save_videos
+
+# or load the general YOLOv3 pedestrian detector and ReID model to get tracking results
+CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml --video_file={your video name}.mp4 --scaled=True --save_videos
 ```
 **Notes:**
 Please make sure that [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first; on the Linux (Ubuntu) platform you can install it directly with: `apt-get update && apt-get install -y ffmpeg`.
+`--scaled` indicates whether the coordinates output by the detector have already been scaled back to the original image; set it to False for the JDE YOLOv3 detector and to True for a general detector.
 
 ### 3. Export model
 ```bash
-1.export detection model
+# 1. Export the detection model
+# export the JDE YOLOv3 pedestrian detector
 CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/jde_yolov3_darknet53_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams
 
-2.export ReID model
-CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
-or
+# or export the general YOLOv3 pedestrian detector
+CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/pedestrian/pedestrian_yolov3_darknet.yml -o weights=https://paddledet.bj.bcebos.com/models/pedestrian_yolov3_darknet.pdparams
+
+# 2. Export the ReID model
 CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
 ```
 
 ### 4. Using the exported model for Python inference
 ```bash
-python deploy/python/mot_sde_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_yolov3_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
+# use the exported JDE YOLOv3 pedestrian detector
+python deploy/python/mot_sde_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
+
+# or use the exported general YOLOv3 pedestrian detector
+python deploy/python/mot_sde_infer.py --model_dir=output_inference/pedestrian_yolov3_darknet/ --reid_model_dir=output_inference/deepsort_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --scaled=True --save_mot_txts
 ```
 **Notes:**
-The tracking model is used to predict the video, and does not support the prediction of a single image. The visualization video of the tracking results is saved by default. You can add `--save_mot_txts` to save the txt result file, or `--save_images` to save the visualization images.
+The tracking model is used to predict a video and does not support single-image prediction. The video with visualized tracking results is saved by default. You can add `--save_mot_txts` (save one txt per video) or `--save_mot_txt_per_img` (save one txt per frame) to save the txt result files, or `--save_images` to save the visualization images.
+`--scaled` indicates whether the coordinates output by the detector have already been scaled back to the original image; set it to False for the JDE YOLOv3 detector and to True for a general detector.
 
 ## Citations
 ```
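The `--scaled` flag that this patch threads through evaluation, inference, and deployment boils down to a single branch in post-processing. The sketch below illustrates the idea; `scale_coords` here is a simplified stand-in for ppdet's helper and assumes letterbox-style resizing, so treat it as illustrative rather than the exact implementation:

```python
import numpy as np

def scale_coords(coords, input_shape, im_shape, ratio):
    # simplified stand-in for ppdet's scale_coords: remove the letterbox
    # padding, then undo the resize, so boxes land in original-image pixels
    pad_w = (input_shape[1] - im_shape[1] * ratio) / 2.
    pad_h = (input_shape[0] - im_shape[0] * ratio) / 2.
    coords = coords.copy()
    coords[:, 0::2] = (coords[:, 0::2] - pad_w) / ratio  # x1, x2
    coords[:, 1::2] = (coords[:, 1::2] - pad_h) / ratio  # y1, y2
    return coords

def postprocess_boxes(boxes, input_shape, im_shape, ratio, scaled):
    # boxes[:, 2:] holds x1, y1, x2, y2 as produced by the detector
    if not scaled:
        # JDE YOLOv3: coordinates are still in network-input space
        return scale_coords(boxes[:, 2:], input_shape, im_shape, ratio)
    # general detector: coordinates are already in original-image space
    return boxes[:, 2:]
```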
diff --git a/configs/mot/deepsort/README_cn.md b/configs/mot/deepsort/README_cn.md
index 82e128471f2df7355232c00136e05d9868545e05..548c945f493f27d613c201187e37bbf1777a7fa7 100644
--- a/configs/mot/deepsort/README_cn.md
+++ b/configs/mot/deepsort/README_cn.md
@@ -56,7 +56,7 @@ wget https://dataset.bj.bcebos.com/mot/det_results_dir.zip
 - `width,height` is the real pixel width and height
 - `conf` is the object score, set to `1` (the detections have already been filtered by the detection score threshold)
 
-- The second way is to load the detection model and the ReID model at the same time; here the JDE version of YOLOv3 is selected, see `configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml` for the configuration.
+- The second way is to load the detection model and the ReID model at the same time; here the JDE version of YOLOv3 is selected, see `configs/mot/deepsort/_base_/deepsort_jde_yolov3_darknet53_pcb_pyramid_r101.yml` for the configuration. To load another general detection model, adapt `configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml`.
 
 ## Getting Started
 
@@ -66,42 +66,61 @@ wget https://dataset.bj.bcebos.com/mot/det_results_dir.zip
 # Load the detection result file and the ReID model to get the tracking results
 CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml --det_results_dir {your detection results}
 
-# Load the detection model and the ReID model to get the tracking results
-CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml
+# Load the JDE YOLOv3 pedestrian detection model and the ReID model to get the tracking results
+CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_jde_yolov3_pcb_pyramid_r101.yml
+
+# or load the general YOLOv3 pedestrian detection model and the ReID model to get the tracking results
+CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml --scaled=True
 ```
+**Notes:**
+The JDE YOLOv3 pedestrian detection model is trained on the same MOT dataset as JDE and FairMOT; the biggest difference from the general YOLOv3 model is that it uses JDEBBoxPostProcess post-processing, so its output coordinates are not scaled back to the original image.
+The general YOLOv3 pedestrian detection model is not trained on the MOT dataset, so its accuracy is lower, but its output coordinates are scaled back to the original image.
+`--scaled` indicates whether the coordinates in the model output have already been scaled back to the original image; set it to False for the JDE YOLOv3 detector and to True for a general detector.
 
 ### 2. Inference
 Predict a video on a single GPU with the following command, and save the result as a video:
 ```bash
-# Load the detection model and the ReID model to get the tracking results
-CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml --video_file={your video name}.mp4 --save_videos
+# Load the JDE YOLOv3 pedestrian detection model and the ReID model, and save the result as a video
+CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsort_jde_yolov3_pcb_pyramid_r101.yml --video_file={your video name}.mp4 --save_videos
+
+# or load the general YOLOv3 pedestrian detection model and the ReID model, and save the result as a video
+CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml --video_file={your video name}.mp4 --scaled=True --save_videos
 ```
 **Notes:**
 Please make sure that [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first; on the Linux (Ubuntu) platform you can install it directly with: `apt-get update && apt-get install -y ffmpeg`.
+`--scaled` indicates whether the coordinates in the model output have already been scaled back to the original image; set it to False for the JDE YOLOv3 detector and to True for a general detector.
+
 ### 3. Export the prediction model
 
 ```bash
-1. Export the detection model first
+# 1. Export the detection model first
+# export the JDE YOLOv3 pedestrian detection model
 CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/jde_yolov3_darknet53_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams
 
-2. Then export the ReID model
-CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
-or
+# or export the general YOLOv3 pedestrian detection model
+CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/pedestrian/pedestrian_yolov3_darknet.yml -o weights=https://paddledet.bj.bcebos.com/models/pedestrian_yolov3_darknet.pdparams
+
+# 2. Then export the ReID model
 CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
 ```
 
 ### 4. Using the exported model for Python inference
 ```bash
-python deploy/python/mot_sde_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_yolov3_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
+# use the exported JDE YOLOv3 pedestrian detection model
+python deploy/python/mot_sde_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
+
+# or use the exported general YOLOv3 pedestrian detection model
+python deploy/python/mot_sde_infer.py --model_dir=output_inference/pedestrian_yolov3_darknet/ --reid_model_dir=output_inference/deepsort_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --scaled=True --save_mot_txts
 ```
 **Notes:**
-The tracking model predicts on a video and does not support single-image prediction; the video with visualized tracking results is saved by default. You can add `--save_mot_txts` to save the txt result file, or `--save_images` to save the visualized images.
+The tracking model predicts on a video and does not support single-image prediction; the video with visualized tracking results is saved by default. You can add `--save_mot_txts` (save one txt per video) or `--save_mot_txt_per_img` (save one txt per frame) to save the txt result files, or `--save_images` to save the visualized images.
+`--scaled` indicates whether the coordinates in the model output have already been scaled back to the original image; set it to False for the JDE YOLOv3 detector and to True for a general detector.
 
 ## Citations
 ```
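Both READMEs reference the txt files written by `--save_mot_txts` and the new `--save_mot_txt_per_img`. A minimal reader sketch, assuming the MOT-Challenge style lines produced by `write_mot_results` (`frame,id,x1,y1,w,h,score,-1,-1,-1`):

```python
from collections import defaultdict

def load_mot_results(txt_path):
    """Read a tracking txt saved by --save_mot_txts into a per-frame dict."""
    results = defaultdict(list)  # frame_id -> [(track_id, (x1, y1, w, h), score)]
    with open(txt_path) as f:
        for line in f:
            fields = line.strip().split(',')
            frame_id = int(fields[0])
            track_id = int(float(fields[1]))  # -1 marks the detection-only fallback
            x1, y1, w, h, score = map(float, fields[2:7])
            results[frame_id].append((track_id, (x1, y1, w, h), score))
    return results
```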
diff --git a/configs/mot/deepsort/_base_/deepsort_jde_yolov3_darknet53_pcb_pyramid_r101.yml b/configs/mot/deepsort/_base_/deepsort_jde_yolov3_darknet53_pcb_pyramid_r101.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fbab9a35e23c89411d3f9303569fb79a53db55ab
--- /dev/null
+++ b/configs/mot/deepsort/_base_/deepsort_jde_yolov3_darknet53_pcb_pyramid_r101.yml
@@ -0,0 +1,59 @@
+architecture: DeepSORT
+pretrain_weights: None
+
+DeepSORT:
+  detector: YOLOv3 # JDE version
+  reid: PCBPyramid
+  tracker: DeepSORTTracker
+
+PCBPyramid:
+  num_conv_out_channels: 128
+  num_classes: 751
+
+DeepSORTTracker:
+  budget: 100
+  max_age: 70
+  n_init: 3
+  metric_type: cosine
+  matching_threshold: 0.2
+  max_iou_distance: 0.9
+  motion: KalmanFilter
+
+
+# JDE version of the YOLOv3 detector for the MOT dataset.
+# The most obvious difference is JDEBBoxPostProcess: the output bbox
+# coordinates are not scaled back to the original image.
+YOLOv3:
+  backbone: DarkNet
+  neck: YOLOv3FPN
+  yolo_head: YOLOv3Head
+  post_process: JDEBBoxPostProcess
+
+DarkNet:
+  depth: 53
+  return_idx: [2, 3, 4]
+  freeze_norm: True
+
+YOLOv3FPN:
+  freeze_norm: True
+
+YOLOv3Head:
+  anchors: [[128,384], [180,540], [256,640], [512,640],
+            [32,96], [45,135], [64,192], [90,271],
+            [8,24], [11,34], [16,48], [23,68]]
+  anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
+  loss: JDEDetectionLoss
+
+JDEBBoxPostProcess:
+  decode:
+    name: JDEBox
+    conf_thresh: 0.3
+    downsample_ratio: 32
+  nms:
+    name: MultiClassNMS
+    keep_top_k: 500
+    score_threshold: 0.01
+    nms_threshold: 0.5
+    nms_top_k: 2000
+    normalized: true
+  return_idx: false
diff --git a/configs/mot/deepsort/_base_/deepsort_reader_1088x608.yml b/configs/mot/deepsort/_base_/deepsort_reader_1088x608.yml
index 1bbc28fd870aaf0a1f000f2754ea6101f07e810a..6ab950aa94e0ea203ea6184d7e3910164ef85993 100644
--- a/configs/mot/deepsort/_base_/deepsort_reader_1088x608.yml
+++ b/configs/mot/deepsort/_base_/deepsort_reader_1088x608.yml
@@ -1,3 +1,7 @@
+# DeepSORT does not need to be trained on the MOT dataset; it is only evaluated on it.
+# The MOT dataset is only used to train the detector (like YOLOv3) with the bboxes alone;
+# the gt IDs are not needed for training.
+
 EvalMOTReader:
   sample_transforms:
     - Decode: {}
diff --git a/configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml b/configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml
index 442ced2bbeee7e1078e801658ed4faca0a7d3b2c..1f9ad5234eb6d397809247143ead41aa2e89ea81 100644
--- a/configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml
+++ b/configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml
@@ -2,55 +2,57 @@ architecture: DeepSORT
 pretrain_weights: None
 
 DeepSORT:
-  detector: YOLOv3 # JDE version
+  detector: YOLOv3 # General version
   reid: PCBPyramid
   tracker: DeepSORTTracker
 
-# JDE version for MOT dataset
+PCBPyramid:
+  num_conv_out_channels: 128
+  num_classes: 751
+
+DeepSORTTracker:
+  budget: 100
+  max_age: 70
+  n_init: 3
+  metric_type: cosine
+  matching_threshold: 0.2
+  max_iou_distance: 0.9
+  motion: KalmanFilter
+
+
+# General version of YOLOv3.
+# Uses BBoxPostProcess; the output bboxes are scaled back to the original image.
 YOLOv3:
   backbone: DarkNet
   neck: YOLOv3FPN
   yolo_head: YOLOv3Head
-  post_process: JDEBBoxPostProcess
+  post_process: BBoxPostProcess
+
+norm_type: sync_bn
 
 DarkNet:
   depth: 53
   return_idx: [2, 3, 4]
-  freeze_norm: True
 
-YOLOv3FPN:
-  freeze_norm: True
+# use the default config
+# YOLOv3FPN:
 
 YOLOv3Head:
-  anchors: [[128,384], [180,540], [256,640], [512,640],
-            [32,96], [45,135], [64,192], [90,271],
-            [8,24], [11,34], [16,48], [23,68]]
-  anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
-  loss: JDEDetectionLoss
+  anchors: [[10, 13], [16, 30], [33, 23],
+            [30, 61], [62, 45], [59, 119],
+            [116, 90], [156, 198], [373, 326]]
+  anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+  loss: YOLOv3Loss
 
-JDEBBoxPostProcess:
+BBoxPostProcess:
   decode:
-    name: JDEBox
-    conf_thresh: 0.3
+    name: YOLOBox
+    conf_thresh: 0.005
     downsample_ratio: 32
+    clip_bbox: true
   nms:
     name: MultiClassNMS
-    keep_top_k: 500
+    keep_top_k: 100
     score_threshold: 0.01
-    nms_threshold: 0.5
-    nms_top_k: 2000
-    normalized: true
-  return_idx: false
-
-PCBPyramid:
-  num_conv_out_channels: 128
-  num_classes: 751
-
-DeepSORTTracker:
-  budget: 100
-  max_age: 70
-  n_init: 3
-  metric_type: cosine
-  matching_threshold: 0.2
-  max_iou_distance: 0.9
-  motion: KalmanFilter
+    nms_threshold: 0.45
+    nms_top_k: 1000
diff --git a/configs/mot/deepsort/deepsort_jde_yolov3_pcb_pyramid_r101.yml b/configs/mot/deepsort/deepsort_jde_yolov3_pcb_pyramid_r101.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6ae498ccabd6ea2cca43fc942e03ac70cd2baf7c
--- /dev/null
+++ b/configs/mot/deepsort/deepsort_jde_yolov3_pcb_pyramid_r101.yml
@@ -0,0 +1,29 @@
+_BASE_: [
+  '../../datasets/mot.yml',
+  '../../runtime.yml',
+  '_base_/deepsort_jde_yolov3_darknet53_pcb_pyramid_r101.yml',
+  '_base_/deepsort_reader_1088x608.yml',
+]
+
+EvalMOTDataset:
+  !MOTImageFolder
+    dataset_dir: dataset/mot
+    data_root: MOT16/images/train
+    keep_ori_im: True # set as True in DeepSORT
+
+det_weights: https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams
+reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
+
+DeepSORT:
+  detector: YOLOv3
+  reid: PCBPyramid
+  tracker: DeepSORTTracker
+
+# JDE version of the YOLOv3 detector for the MOT dataset.
+# The most obvious difference is JDEBBoxPostProcess: the output bbox
+# coordinates are not scaled back to the original image.
+YOLOv3:
+  backbone: DarkNet
+  neck: YOLOv3FPN
+  yolo_head: YOLOv3Head
+  post_process: JDEBBoxPostProcess
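A quick way to sanity-check which post-processor a given tracking config resolves to is to load it with ppdet's config workspace; a sketch, assuming a PaddleDetection checkout as the working directory and that the merged config exposes the `YOLOv3` section as a plain mapping:

```python
from ppdet.core.workspace import load_config

for name in ('deepsort_jde_yolov3_pcb_pyramid_r101',
             'deepsort_yolov3_pcb_pyramid_r101'):
    cfg = load_config('configs/mot/deepsort/%s.yml' % name)
    # expected: JDEBBoxPostProcess (run with --scaled=False) for the JDE config,
    # BBoxPostProcess (run with --scaled=True) for the general config
    print(name, '->', cfg['YOLOv3']['post_process'])
```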
diff --git a/configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml b/configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml
index 58aa29ed1832e8a24342fdd3d4ba44171760e6a6..fe8525da0256170fcbb56a3d71208126191a3d74 100644
--- a/configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml
+++ b/configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml
@@ -11,7 +11,7 @@ EvalMOTDataset:
     data_root: MOT16/images/train
     keep_ori_im: True # set as True in DeepSORT
 
-det_weights: https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams
+det_weights: https://paddledet.bj.bcebos.com/models/pedestrian_yolov3_darknet.pdparams
 reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
 
@@ -19,9 +19,10 @@ DeepSORT:
   reid: PCBPyramid
   tracker: DeepSORTTracker
 
-# JDE version for MOT dataset
+# General version of YOLOv3.
+# Uses BBoxPostProcess; the output bboxes are scaled back to the original image.
 YOLOv3:
   backbone: DarkNet
   neck: YOLOv3FPN
   yolo_head: YOLOv3Head
-  post_process: JDEBBoxPostProcess
+  post_process: BBoxPostProcess
diff --git a/deploy/python/mot_jde_infer.py b/deploy/python/mot_jde_infer.py
index 4d6b33a219e832f201554213f5f2f9f8325e990c..0ff6be5778a3fa41bb99e94ee72773c5a705a964 100644
--- a/deploy/python/mot_jde_infer.py
+++ b/deploy/python/mot_jde_infer.py
@@ -92,7 +92,9 @@ class JDE_Detector(Detector):
     def postprocess(self, pred_dets, pred_embs, threshold):
         online_targets = self.tracker.update(pred_dets, pred_embs)
         if online_targets == []:
-            return [pred_dets[0][:4]], [pred_dets[0][4]], [1]
+            # In the first few frames the model may have detections but no
+            # tracking results yet; fall back to the detections and set id to -1.
+            return [pred_dets[0][:4]], [pred_dets[0][4]], [-1]
         online_tlwhs, online_ids = [], []
         online_scores = []
         for t in online_targets:
@@ -162,8 +164,6 @@ def write_mot_results(filename, results, data_type='mot'):
         if data_type == 'kitti':
             frame_id -= 1
         for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
-            if track_id < 0:
-                continue
             x1, y1, w, h = tlwh
             x2, y2 = x1 + w, y1 + h
             line = save_format.format(
@@ -254,6 +254,15 @@ def predict_video(detector, camera_id):
                 os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
         else:
             writer.write(im)
+
+        if FLAGS.save_mot_txt_per_img:
+            save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
+            if not os.path.exists(save_dir):
+                os.makedirs(save_dir)
+            result_filename = os.path.join(save_dir,
+                                           '{:05d}.txt'.format(frame_id))
+            write_mot_results(result_filename, [results[-1]])
+
         frame_id += 1
         print('detect frame:%d' % (frame_id))
     if camera_id != -1:
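The `--save_mot_txt_per_img` branch added above writes one `{frame_id:05d}.txt` per frame into a directory named after the video. If a single per-video file is wanted afterwards, the per-frame files can simply be concatenated in frame order; a small sketch (the paths are assumptions):

```python
import glob
import os

def merge_per_img_txts(save_dir, out_file):
    # the per-frame files are zero-padded ({:05d}.txt), so a lexical
    # sort is also a frame-order sort
    with open(out_file, 'w') as out:
        for txt in sorted(glob.glob(os.path.join(save_dir, '*.txt'))):
            with open(txt) as f:
                out.write(f.read())

# e.g. merge_per_img_txts('output/myvideo', 'output/myvideo_all.txt')
```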
diff --git a/deploy/python/mot_sde_infer.py b/deploy/python/mot_sde_infer.py
index 923fa62cbff58e2c628e4c80a9f06f2e8b054ee5..dd1dbe9280309b83ee3a2b3a88f4582b836610e0 100644
--- a/deploy/python/mot_sde_infer.py
+++ b/deploy/python/mot_sde_infer.py
@@ -135,20 +135,29 @@ class SDE_Detector(Detector):
             enable_mkldnn=enable_mkldnn)
         assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
 
-    def postprocess(self, boxes, input_shape, im_shape, scale_factor,
-                    threshold):
-        pred_bboxes = scale_coords(boxes[:, 2:], input_shape, im_shape,
+    def postprocess(self, boxes, input_shape, im_shape, scale_factor, threshold,
+                    scaled):
+        if not scaled:
+            # postprocess output of the JDE YOLOv3 detector
+            pred_bboxes = scale_coords(boxes[:, 2:], input_shape, im_shape,
+                                       scale_factor)
+            pred_bboxes = clip_box(pred_bboxes, input_shape, im_shape,
                                    scale_factor)
-        pred_bboxes = clip_box(pred_bboxes, input_shape, im_shape, scale_factor)
+        else:
+            # postprocess output of a general detector
+            pred_bboxes = boxes[:, 2:]
+
         pred_scores = boxes[:, 1:2]
         keep_mask = pred_scores[:, 0] >= threshold
         return pred_bboxes[keep_mask], pred_scores[keep_mask]
 
-    def predict(self, image, threshold=0.5, warmup=0, repeats=1):
+    def predict(self, image, scaled, threshold=0.5, warmup=0, repeats=1):
         '''
         Args:
             image (np.ndarray): image numpy data
             threshold (float): threshold of predicted box' score
+            scaled (bool): whether the coords after detector outputs are scaled,
+                default False in JDE YOLOv3, set True for a general detector.
         Returns:
             pred_bboxes, pred_scores (np.ndarray)
         '''
@@ -181,7 +190,7 @@ class SDE_Detector(Detector):
         im_shape = inputs['im_shape']
         scale_factor = inputs['scale_factor']
         pred_bboxes, pred_scores = self.postprocess(
-            boxes, input_shape, im_shape, scale_factor, threshold)
+            boxes, input_shape, im_shape, scale_factor, threshold, scaled)
         self.det_times.postprocess_time_s.end()
         self.det_times.img_num += 1
         return pred_bboxes, pred_scores
@@ -302,14 +311,14 @@ def predict_image(detector, reid_model, image_list):
         frame = cv2.imread(img_file)
         if FLAGS.run_benchmark:
             pred_bboxes, pred_scores = detector.predict(
-                [frame], FLAGS.threshold, warmup=10, repeats=10)
+                [frame], FLAGS.scaled, FLAGS.threshold, warmup=10, repeats=10)
             cm, gm, gu = get_current_memory_mb()
             detector.cpu_mem += cm
             detector.gpu_mem += gm
             detector.gpu_util += gu
             print('Test iter {}, file name:{}'.format(i, img_file))
         else:
-            pred_bboxes, pred_scores = detector.predict([frame],
+            pred_bboxes, pred_scores = detector.predict([frame], FLAGS.scaled,
                                                          FLAGS.threshold)
 
         # process
@@ -319,7 +328,8 @@ def predict_image(detector, reid_model, image_list):
             axis=1)
         crops, pred_scores = reid_model.get_crops(
             pred_bboxes, frame, pred_scores, w=64, h=192)
-
+        if len(crops) == 0:
+            continue
         if FLAGS.run_benchmark:
             online_tlwhs, online_scores, online_ids = reid_model.predict(
                 crops, bbox_tlwh, pred_scores, warmup=10, repeats=10)
@@ -366,7 +376,8 @@ def predict_video(detector, reid_model, camera_id):
         if not ret:
             break
         timer.tic()
-        pred_bboxes, pred_scores = detector.predict([frame], FLAGS.threshold)
+        pred_bboxes, pred_scores = detector.predict([frame], FLAGS.scaled,
+                                                    FLAGS.threshold)
         timer.toc()
         bbox_tlwh = np.concatenate(
             (pred_bboxes[:, 0:2],
@@ -374,7 +385,8 @@ def predict_video(detector, reid_model, camera_id):
             axis=1)
         crops, pred_scores = reid_model.get_crops(
             pred_bboxes, frame, pred_scores, w=64, h=192)
-
+        if len(crops) == 0:
+            continue
         online_tlwhs, online_scores, online_ids = reid_model.predict(
             crops, bbox_tlwh, pred_scores)
 
@@ -395,6 +407,23 @@ def predict_video(detector, reid_model, camera_id):
                 os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
         else:
             writer.write(im)
+
+        if FLAGS.save_mot_txt_per_img:
+            save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
+            if not os.path.exists(save_dir):
+                os.makedirs(save_dir)
+            result_filename = os.path.join(save_dir,
+                                           '{:05d}.txt'.format(frame_id))
+            # In the first few frames the model may have detections but no
+            # tracking results yet; fall back to the detections and set id to -1.
+            if results[-1][2] == []:
+                tlwhs = [tlwh for tlwh in bbox_tlwh]
+                scores = [score[0] for score in pred_scores]
+                result = (frame_id + 1, tlwhs, scores, [-1] * len(tlwhs))
+            else:
+                result = results[-1]
+            write_mot_results(result_filename, [result])
+
         frame_id += 1
         print('detect frame:%d' % (frame_id))
     if camera_id != -1:
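Both `predict_image` and `predict_video` above convert the detector's `x1, y1, x2, y2` boxes into the `tlwh` layout expected by the ReID model and `write_mot_results`, using the same `np.concatenate` expression. In isolation:

```python
import numpy as np

# xyxy -> tlwh, exactly as done before reid_model.get_crops in this patch
pred_bboxes = np.array([[10., 20., 110., 220.]])  # x1, y1, x2, y2
bbox_tlwh = np.concatenate(
    (pred_bboxes[:, 0:2], pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2]), axis=1)
print(bbox_tlwh)  # [[ 10.  20. 100. 200.]] i.e. x1, y1, w, h
```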
diff --git a/deploy/python/utils.py b/deploy/python/utils.py
index c35364c01765ea842a3bc0d16eaf089f674e717e..1a7fe211bbd4ff9caec88a45bea8a0f74e8b2abe 100644
--- a/deploy/python/utils.py
+++ b/deploy/python/utils.py
@@ -108,6 +108,16 @@ def argsparser():
         '--save_mot_txts',
         action='store_true',
         help='Save tracking results (txt).')
+    parser.add_argument(
+        '--save_mot_txt_per_img',
+        action='store_true',
+        help='Save tracking results (txt) for each image.')
+    parser.add_argument(
+        '--scaled',
+        type=bool,
+        default=False,
+        help="Whether coords after detector outputs are scaled: False for "
+        "JDE YOLOv3, True for a general detector.")
     parser.add_argument(
         "--reid_model_dir",
         type=str,
diff --git a/ppdet/engine/tracker.py b/ppdet/engine/tracker.py
index 0d4f98b68c5f9e29e26a72fd319f27e7e2ffc846..91497d4c6844e76e65af3645a7641a602b48a10a 100644
--- a/ppdet/engine/tracker.py
+++ b/ppdet/engine/tracker.py
@@ -166,6 +166,7 @@ class Tracker(object):
                       save_dir=None,
                       show_image=False,
                       frame_rate=30,
+                      scaled=False,
                       det_file='',
                       draw_threshold=0):
         if save_dir:
@@ -211,8 +212,12 @@ class Tracker(object):
             else:
                 outs = self.model.detector(data)
                 if outs['bbox_num'] > 0:
-                    pred_bboxes = scale_coords(outs['bbox'][:, 2:], input_shape,
-                                               im_shape, scale_factor)
+                    if not scaled:
+                        pred_bboxes = scale_coords(outs['bbox'][:, 2:],
+                                                   input_shape, im_shape,
+                                                   scale_factor)
+                    else:
+                        pred_bboxes = outs['bbox'][:, 2:]
                     pred_scores = outs['bbox'][:, 1:2]
                 else:
                     pred_bboxes = []
@@ -270,6 +275,7 @@ class Tracker(object):
                      save_images=False,
                      save_videos=False,
                      show_image=False,
+                     scaled=False,
                      det_results_dir=''):
         if not os.path.exists(output_dir): os.makedirs(output_dir)
         result_root = os.path.join(output_dir, 'mot_results')
@@ -318,6 +324,7 @@ class Tracker(object):
                     save_dir=save_dir,
                     show_image=show_image,
                     frame_rate=frame_rate,
+                    scaled=scaled,
                     det_file=os.path.join(det_results_dir,
                                           '{}.txt'.format(seq)))
             else:
@@ -382,6 +389,7 @@ class Tracker(object):
                       save_images=False,
                       save_videos=True,
                       show_image=False,
+                      scaled=False,
                       det_results_dir='',
                       draw_threshold=0.5):
         assert video_file is not None or image_dir is not None, \
@@ -438,6 +446,7 @@ class Tracker(object):
                     save_dir=save_dir,
                     show_image=show_image,
                     frame_rate=frame_rate,
+                    scaled=scaled,
                     det_file=os.path.join(det_results_dir,
                                           '{}.txt'.format(seq)),
                     draw_threshold=draw_threshold)
diff --git a/tools/eval_mot.py b/tools/eval_mot.py
index fb8e296124dd9ef5470216e4d00161446b8349be..14e15ebb2b41d3a1fb54f4dd597a72e29b28bf70 100644
--- a/tools/eval_mot.py
+++ b/tools/eval_mot.py
@@ -62,6 +62,12 @@ def parse_args():
         '--show_image',
         action='store_true',
         help='Show tracking results (image).')
+    parser.add_argument(
+        '--scaled',
+        type=bool,
+        default=False,
+        help="Whether coords after detector outputs are scaled: False for "
+        "JDE YOLOv3, True for a general detector.")
     args = parser.parse_args()
     return args
 
@@ -95,6 +101,7 @@ def run(FLAGS, cfg):
         save_images=FLAGS.save_images,
         save_videos=FLAGS.save_videos,
         show_image=FLAGS.show_image,
+        scaled=FLAGS.scaled,
         det_results_dir=FLAGS.det_results_dir)
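One caveat about `--scaled` as defined here and in `tools/infer_mot.py` below: `argparse` with `type=bool` applies Python's `bool()` to the raw string, and any non-empty string is truthy, so `--scaled=False` on the command line still parses as `True`. That is presumably why the docs only ever pass `--scaled=True`, or omit the flag to keep the `False` default:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--scaled', type=bool, default=False)

print(parser.parse_args([]).scaled)                  # False (the default)
print(parser.parse_args(['--scaled=True']).scaled)   # True
print(parser.parse_args(['--scaled=False']).scaled)  # True, since bool('False') is truthy
```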
diff --git a/tools/infer_mot.py b/tools/infer_mot.py
index 1131bcd0751560a54d9fd8a69ea8e5edd4429fe8..9054f1575df1b13c4460d5bbe1b1b556b492d0fe 100644
--- a/tools/infer_mot.py
+++ b/tools/infer_mot.py
@@ -74,6 +74,12 @@ def parse_args():
         '--show_image',
         action='store_true',
         help='Show tracking results (image).')
+    parser.add_argument(
+        '--scaled',
+        type=bool,
+        default=False,
+        help="Whether coords after detector outputs are scaled: False for "
+        "JDE YOLOv3, True for a general detector.")
     parser.add_argument(
         "--draw_threshold",
         type=float,
@@ -107,6 +113,7 @@ def run(FLAGS, cfg):
         save_images=FLAGS.save_images,
         save_videos=FLAGS.save_videos,
         show_image=FLAGS.show_image,
+        scaled=FLAGS.scaled,
         det_results_dir=FLAGS.det_results_dir,
         draw_threshold=FLAGS.draw_threshold)