未验证 提交 dbfc8c91 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] fix deepsort yolov3 infer and deploy (#4277)

* add general yolov3 for deepsort

* refine format

* fix no target infer

* fix readme and conflict
上级 6bf1b443
......@@ -55,7 +55,7 @@ If you use a stronger detection model, you can get better results. Each txt is t
- `width,height` is the pixel width and height
- `conf` is the object score with default value `1` (the results had been filtered out according to the detection score threshold)
- 2.Load the detection model and the ReID model at the same time. Here, the JDE version of YOLOv3 is selected. For more detail of configuration, see `configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml`.
- 2. Load the detection model and the ReID model at the same time. Here, the JDE version of YOLOv3 is selected. For more detail of configuration, see `configs/mot/deepsort/_base_/deepsort_jde_yolov3_darknet53_pcb_pyramid_r101.yml`. Load other general detection model, you can refer to `configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml`.
## Getting Start
......@@ -65,40 +65,60 @@ If you use a stronger detection model, you can get better results. Each txt is t
# Load the result file and ReID model to get the tracking result
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml --det_results_dir {your detection results}
# Load the detection model and ReID model to get the tracking results
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml
# Load JDE YOLOv3 detector and ReID model to get the tracking results
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_jde_yolov3_pcb_pyramid_r101.yml
# or Load genernal YOLOv3 detector and ReID model to get the tracking results
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml --scaled=True
```
**Notes:**
JDE YOLOv3 pedestrian detector is trained with the same MOT dataset as JDE and FairMOT. In addition, the biggest difference between this model and general YOLOv3 model is that JDEBBoxPostProcess post-processing, and the output coordinates are not scaled back to the original image.
General YOLOv3 pedestrian detector is not trained on MOT dataset, so the performance is lower. But the output coordinates are scaled back to the original image.
`--scaled` means whether the coords after detector outputs are scaled back to the original image, False in JDE YOLOv3, True in general detector.
### 2. Inference
Inference a vidoe on single GPU with following command:
```bash
# inference on video and save a video
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml --video_file={your video name}.mp4 --save_videos
# load JDE YOLOv3 pedestrian detector and ReID model to get tracking results
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsort_jde_yolov3_pcb_pyramid_r101.yml --video_file={your video name}.mp4 --save_videos
# or load general YOLOv3 pedestrian detector and ReID model to get tracking results
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml --video_file={your video name}.mp4 --scaled=True --save_videos
```
**Notes:**
Please make sure that [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first, on Linux(Ubuntu) platform you can directly install it by the following command:`apt-get update && apt-get install -y ffmpeg`.
`--scaled` means whether the coords after detector outputs are scaled back to the original image, False in JDE YOLOv3, True in general detector.
### 3. Export model
```bash
1.export detection model
# 1.export detection model
# export JDE YOLOv3 pedestrian detector
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/jde_yolov3_darknet53_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams
2.export ReID model
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
or
# or export general YOLOv3 pedestrian detector
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/pedestrian/pedestrian_yolov3_darknet.yml -o weights=https://paddledet.bj.bcebos.com/models/pedestrian_yolov3_darknet.pdparams
# 2. export ReID model
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
```
### 4. Using exported model for python inference
```bash
python deploy/python/mot_sde_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_yolov3_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
# using exported JDE YOLOv3 pedestrian detector
python deploy/python/mot_sde_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
# or using exported general YOLOv3 pedestrian detector
python deploy/python/mot_sde_infer.py --model_dir=output_inference/pedestrian_yolov3_darknet/ --reid_model_dir=output_inference/deepsort_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --scaled=True --save_mot_txts
```
**Notes:**
The tracking model is used to predict the video, and does not support the prediction of a single image. The visualization video of the tracking results is saved by default. You can add `--save_mot_txts` to save the txt result file, or `--save_images` to save the visualization images.
The tracking model is used to predict the video, and does not support the prediction of a single image. The visualization video of the tracking results is saved by default. You can add `--save_mot_txts`(save a txt for every video) or `--save_mot_txt_per_img`(save a txt for every image) to save the txt result file, or `--save_images` to save the visualization images.
`--scaled` means whether the coords after detector outputs are scaled back to the original image, False in JDE YOLOv3, True in general detector.
## Citations
```
......
......@@ -56,7 +56,7 @@ wget https://dataset.bj.bcebos.com/mot/det_results_dir.zip
- `width,height`是真实的像素宽高
- `conf`是目标得分设置为`1`(已经按检测的得分阈值筛选出的检测结果)
- 第2种方式是同时加载检测模型和ReID模型,此处选用JDE版本的YOLOv3,具体配置见`configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml`
- 第2种方式是同时加载检测模型和ReID模型,此处选用JDE版本的YOLOv3,具体配置见`configs/mot/deepsort/_base_/deepsort_jde_yolov3_darknet53_pcb_pyramid_r101.yml`。加载其他通用检测模型可参照`configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml`进行修改。
## 快速开始
......@@ -66,42 +66,61 @@ wget https://dataset.bj.bcebos.com/mot/det_results_dir.zip
# 加载检测结果文件和ReID模型,得到跟踪结果
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml --det_results_dir {your detection results}
# 加载检测模型和ReID模型,得到跟踪结果
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml
# 加载JDE YOLOv3行人检测模型和ReID模型,得到跟踪结果
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_jde_yolov3_pcb_pyramid_r101.yml
# 或者加载普通YOLOv3行人检测模型和ReID模型,得到跟踪结果
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml --scaled=True
```
**注意:**
JDE YOLOv3行人检测模型是和JDE和FairMOT使用同样的MOT数据集训练的,这个模型与普通YOLOv3模型最大的区别是使用了JDEBBoxPostProcess后处理,结果输出坐标没有缩放回原图。
普通YOLOv3行人检测模型不是用MOT数据集训练的,所以精度效果更低, 其模型输出坐标是缩放回原图的。
`--scaled`表示在模型输出结果的坐标是否已经是缩放回原图的,如果使用的检测模型是JDE的YOLOv3则为False,如果使用通用检测模型则为True。
### 2. 预测
使用单个GPU通过如下命令预测一个视频,并保存为视频
```bash
# 加载检测模型和ReID模型,得到跟踪结果
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml --video_file={your video name}.mp4 --save_videos
# 加载JDE YOLOv3行人检测模型和ReID模型,并保存为视频
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsort_jde_yolov3_pcb_pyramid_r101.yml --video_file={your video name}.mp4 --save_videos
# 或者加载普通YOLOv3行人检测模型和ReID模型,并保存为视频
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml --video_file={your video name}.mp4 --scaled=True --save_videos
```
**注意:**
请先确保已经安装了[ffmpeg](https://ffmpeg.org/ffmpeg.html), Linux(Ubuntu)平台可以直接用以下命令安装:`apt-get update && apt-get install -y ffmpeg`
`--scaled`表示在模型输出结果的坐标是否已经是缩放回原图的,如果使用的检测模型是JDE的YOLOv3则为False,如果使用通用检测模型则为True。
### 3. 导出预测模型
```bash
1.先导出检测模型
# 1.先导出检测模型
# 导出JDE YOLOv3行人检测模型
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/jde_yolov3_darknet53_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams
2.再导出ReID模型
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
# 或导出普通YOLOv3行人检测模型
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/pedestrian/pedestrian_yolov3_darknet.yml -o weights=https://paddledet.bj.bcebos.com/models/pedestrian_yolov3_darknet.pdparams
# 2.再导出ReID模型
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
```
### 4. 用导出的模型基于Python去预测
```bash
python deploy/python/mot_sde_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_yolov3_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
# 用导出JDE YOLOv3行人检测模型
python deploy/python/mot_sde_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
# 或用导出的普通yolov3行人检测模型
python deploy/python/mot_sde_infer.py --model_dir=output_inference/pedestrian_yolov3_darknet/ --reid_model_dir=output_inference/deepsort_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --scaled=True --save_mot_txts
```
**注意:**
跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`表示保存跟踪结果的txt文件,或`--save_images`表示保存跟踪结果可视化图片。
跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`(对每个视频保存一个txt)或`--save_mot_txt_per_img`(对每张图片保存一个txt)表示保存跟踪结果的txt文件,或`--save_images`表示保存跟踪结果可视化图片。
`--scaled`表示在模型输出结果的坐标是否已经是缩放回原图的,如果使用的检测模型是JDE的YOLOv3则为False,如果使用通用检测模型则为True。
## 引用
```
......
architecture: DeepSORT
pretrain_weights: None
DeepSORT:
detector: YOLOv3 # JDE version
reid: PCBPyramid
tracker: DeepSORTTracker
PCBPyramid:
num_conv_out_channels: 128
num_classes: 751
DeepSORTTracker:
budget: 100
max_age: 70
n_init: 3
metric_type: cosine
matching_threshold: 0.2
max_iou_distance: 0.9
motion: KalmanFilter
# JDE version YOLOv3 detector for MOT dataset.
# The most obvious difference is JDEBBoxPostProcess and the bboxes coordinates
# output are not scaled to the original image.
YOLOv3:
backbone: DarkNet
neck: YOLOv3FPN
yolo_head: YOLOv3Head
post_process: JDEBBoxPostProcess
DarkNet:
depth: 53
return_idx: [2, 3, 4]
freeze_norm: True
YOLOv3FPN:
freeze_norm: True
YOLOv3Head:
anchors: [[128,384], [180,540], [256,640], [512,640],
[32,96], [45,135], [64,192], [90,271],
[8,24], [11,34], [16,48], [23,68]]
anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
loss: JDEDetectionLoss
JDEBBoxPostProcess:
decode:
name: JDEBox
conf_thresh: 0.3
downsample_ratio: 32
nms:
name: MultiClassNMS
keep_top_k: 500
score_threshold: 0.01
nms_threshold: 0.5
nms_top_k: 2000
normalized: true
return_idx: false
# DeepSORT does not need to train on MOT dataset, only used for evaluation.
# MOT dataset needs to be trained on the detector(like YOLOv3) only using bboxes.
# And gt IDs don't need to be trained.
EvalMOTReader:
sample_transforms:
- Decode: {}
......
......@@ -2,55 +2,57 @@ architecture: DeepSORT
pretrain_weights: None
DeepSORT:
detector: YOLOv3 # JDE version
detector: YOLOv3 # General version
reid: PCBPyramid
tracker: DeepSORTTracker
# JDE version for MOT dataset
PCBPyramid:
num_conv_out_channels: 128
num_classes: 751
DeepSORTTracker:
budget: 100
max_age: 70
n_init: 3
metric_type: cosine
matching_threshold: 0.2
max_iou_distance: 0.9
motion: KalmanFilter
# General version YOLOv3
# Using BBoxPostProcess and the bboxes output are scaled to the original image.
YOLOv3:
backbone: DarkNet
neck: YOLOv3FPN
yolo_head: YOLOv3Head
post_process: JDEBBoxPostProcess
post_process: BBoxPostProcess
norm_type: sync_bn
DarkNet:
depth: 53
return_idx: [2, 3, 4]
freeze_norm: True
YOLOv3FPN:
freeze_norm: True
# use default config
# YOLOv3FPN:
YOLOv3Head:
anchors: [[128,384], [180,540], [256,640], [512,640],
[32,96], [45,135], [64,192], [90,271],
[8,24], [11,34], [16,48], [23,68]]
anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
loss: JDEDetectionLoss
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
loss: YOLOv3Loss
JDEBBoxPostProcess:
BBoxPostProcess:
decode:
name: JDEBox
conf_thresh: 0.3
name: YOLOBox
conf_thresh: 0.005
downsample_ratio: 32
clip_bbox: true
nms:
name: MultiClassNMS
keep_top_k: 500
keep_top_k: 100
score_threshold: 0.01
nms_threshold: 0.5
nms_top_k: 2000
normalized: true
return_idx: false
PCBPyramid:
num_conv_out_channels: 128
num_classes: 751
DeepSORTTracker:
budget: 100
max_age: 70
n_init: 3
metric_type: cosine
matching_threshold: 0.2
max_iou_distance: 0.9
motion: KalmanFilter
nms_threshold: 0.45
nms_top_k: 1000
_BASE_: [
'../../datasets/mot.yml',
'../../runtime.yml',
'_base_/deepsort_jde_yolov3_darknet53_pcb_pyramid_r101.yml',
'_base_/deepsort_reader_1088x608.yml',
]
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: MOT16/images/train
keep_ori_im: True # set as True in DeepSORT
det_weights: https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams
reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
DeepSORT:
detector: YOLOv3
reid: PCBPyramid
tracker: DeepSORTTracker
# JDE version YOLOv3 detector for MOT dataset.
# The most obvious difference is JDEBBoxPostProcess and the bboxes coordinates
# output are not scaled to the original image.
YOLOv3:
backbone: DarkNet
neck: YOLOv3FPN
yolo_head: YOLOv3Head
post_process: JDEBBoxPostProcess
......@@ -11,7 +11,7 @@ EvalMOTDataset:
data_root: MOT16/images/train
keep_ori_im: True # set as True in DeepSORT
det_weights: https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams
det_weights: https://paddledet.bj.bcebos.com/models/pedestrian_yolov3_darknet.pdparams
reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
DeepSORT:
......@@ -19,9 +19,10 @@ DeepSORT:
reid: PCBPyramid
tracker: DeepSORTTracker
# JDE version for MOT dataset
# General version YOLOv3
# Using BBoxPostProcess and the bboxes output are scaled to the original image.
YOLOv3:
backbone: DarkNet
neck: YOLOv3FPN
yolo_head: YOLOv3Head
post_process: JDEBBoxPostProcess
post_process: BBoxPostProcess
......@@ -92,7 +92,9 @@ class JDE_Detector(Detector):
def postprocess(self, pred_dets, pred_embs, threshold):
online_targets = self.tracker.update(pred_dets, pred_embs)
if online_targets == []:
return [pred_dets[0][:4]], [pred_dets[0][4]], [1]
# First few frames, the model may have no tracking results but have
# detection results,use the detection results instead, and set id -1.
return [pred_dets[0][:4]], [pred_dets[0][4]], [-1]
online_tlwhs, online_ids = [], []
online_scores = []
for t in online_targets:
......@@ -162,8 +164,6 @@ def write_mot_results(filename, results, data_type='mot'):
if data_type == 'kitti':
frame_id -= 1
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0:
continue
x1, y1, w, h = tlwh
x2, y2 = x1 + w, y1 + h
line = save_format.format(
......@@ -254,6 +254,15 @@ def predict_video(detector, camera_id):
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
else:
writer.write(im)
if FLAGS.save_mot_txt_per_img:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
result_filename = os.path.join(save_dir,
'{:05d}.txt'.format(frame_id))
write_mot_results(result_filename, [results[-1]])
frame_id += 1
print('detect frame:%d' % (frame_id))
if camera_id != -1:
......
......@@ -135,20 +135,29 @@ class SDE_Detector(Detector):
enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
def postprocess(self, boxes, input_shape, im_shape, scale_factor,
threshold):
pred_bboxes = scale_coords(boxes[:, 2:], input_shape, im_shape,
def postprocess(self, boxes, input_shape, im_shape, scale_factor, threshold,
scaled):
if not scaled:
# postprocess output of jde yolov3 detector
pred_bboxes = scale_coords(boxes[:, 2:], input_shape, im_shape,
scale_factor)
pred_bboxes = clip_box(pred_bboxes, input_shape, im_shape,
scale_factor)
pred_bboxes = clip_box(pred_bboxes, input_shape, im_shape, scale_factor)
else:
# postprocess output of general detector
pred_bboxes = boxes[:, 2:]
pred_scores = boxes[:, 1:2]
keep_mask = pred_scores[:, 0] >= threshold
return pred_bboxes[keep_mask], pred_scores[keep_mask]
def predict(self, image, threshold=0.5, warmup=0, repeats=1):
def predict(self, image, scaled, threshold=0.5, warmup=0, repeats=1):
'''
Args:
image (np.ndarray): image numpy data
threshold (float): threshold of predicted box' score
scaled (bool): whether the coords after detector outputs are scaled,
default False in jde yolov3, set True in general detector.
Returns:
pred_bboxes, pred_scores (np.ndarray)
'''
......@@ -181,7 +190,7 @@ class SDE_Detector(Detector):
im_shape = inputs['im_shape']
scale_factor = inputs['scale_factor']
pred_bboxes, pred_scores = self.postprocess(
boxes, input_shape, im_shape, scale_factor, threshold)
boxes, input_shape, im_shape, scale_factor, threshold, scaled)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1
return pred_bboxes, pred_scores
......@@ -302,14 +311,14 @@ def predict_image(detector, reid_model, image_list):
frame = cv2.imread(img_file)
if FLAGS.run_benchmark:
pred_bboxes, pred_scores = detector.predict(
[frame], FLAGS.threshold, warmup=10, repeats=10)
[frame], FLAGS.scaled, FLAGS.threshold, warmup=10, repeats=10)
cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm
detector.gpu_mem += gm
detector.gpu_util += gu
print('Test iter {}, file name:{}'.format(i, img_file))
else:
pred_bboxes, pred_scores = detector.predict([frame],
pred_bboxes, pred_scores = detector.predict([frame], FLAGS.scaled,
FLAGS.threshold)
# process
......@@ -319,7 +328,8 @@ def predict_image(detector, reid_model, image_list):
axis=1)
crops, pred_scores = reid_model.get_crops(
pred_bboxes, frame, pred_scores, w=64, h=192)
if len(crops) == 0:
continue
if FLAGS.run_benchmark:
online_tlwhs, online_scores, online_ids = reid_model.predict(
crops, bbox_tlwh, pred_scores, warmup=10, repeats=10)
......@@ -366,7 +376,8 @@ def predict_video(detector, reid_model, camera_id):
if not ret:
break
timer.tic()
pred_bboxes, pred_scores = detector.predict([frame], FLAGS.threshold)
pred_bboxes, pred_scores = detector.predict([frame], FLAGS.scaled,
FLAGS.threshold)
timer.toc()
bbox_tlwh = np.concatenate(
(pred_bboxes[:, 0:2],
......@@ -374,7 +385,8 @@ def predict_video(detector, reid_model, camera_id):
axis=1)
crops, pred_scores = reid_model.get_crops(
pred_bboxes, frame, pred_scores, w=64, h=192)
if len(crops) == 0:
continue
online_tlwhs, online_scores, online_ids = reid_model.predict(
crops, bbox_tlwh, pred_scores)
......@@ -395,6 +407,23 @@ def predict_video(detector, reid_model, camera_id):
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
else:
writer.write(im)
if FLAGS.save_mot_txt_per_img:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
result_filename = os.path.join(save_dir,
'{:05d}.txt'.format(frame_id))
# First few frames, the model may have no tracking results but have
# detection results,use the detection results instead, and set id -1.
if results[-1][2] == []:
tlwhs = [tlwh for tlwh in bbox_tlwh]
scores = [score[0] for score in pred_scores]
result = (frame_id + 1, tlwhs, scores, [-1] * len(tlwhs))
else:
result = results[-1]
write_mot_results(result_filename, [result])
frame_id += 1
print('detect frame:%d' % (frame_id))
if camera_id != -1:
......
......@@ -108,6 +108,16 @@ def argsparser():
'--save_mot_txts',
action='store_true',
help='Save tracking results (txt).')
parser.add_argument(
'--save_mot_txt_per_img',
action='store_true',
help='Save tracking results (txt) for each image.')
parser.add_argument(
'--scaled',
type=bool,
default=False,
help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
"True in general detector.")
parser.add_argument(
"--reid_model_dir",
type=str,
......
......@@ -166,6 +166,7 @@ class Tracker(object):
save_dir=None,
show_image=False,
frame_rate=30,
scaled=False,
det_file='',
draw_threshold=0):
if save_dir:
......@@ -211,8 +212,12 @@ class Tracker(object):
else:
outs = self.model.detector(data)
if outs['bbox_num'] > 0:
pred_bboxes = scale_coords(outs['bbox'][:, 2:], input_shape,
im_shape, scale_factor)
if not scaled:
pred_bboxes = scale_coords(outs['bbox'][:, 2:],
input_shape, im_shape,
scale_factor)
else:
pred_bboxes = outs['bbox'][:, 2:]
pred_scores = outs['bbox'][:, 1:2]
else:
pred_bboxes = []
......@@ -270,6 +275,7 @@ class Tracker(object):
save_images=False,
save_videos=False,
show_image=False,
scaled=False,
det_results_dir=''):
if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results')
......@@ -318,6 +324,7 @@ class Tracker(object):
save_dir=save_dir,
show_image=show_image,
frame_rate=frame_rate,
scaled=scaled,
det_file=os.path.join(det_results_dir,
'{}.txt'.format(seq)))
else:
......@@ -382,6 +389,7 @@ class Tracker(object):
save_images=False,
save_videos=True,
show_image=False,
scaled=False,
det_results_dir='',
draw_threshold=0.5):
assert video_file is not None or image_dir is not None, \
......@@ -438,6 +446,7 @@ class Tracker(object):
save_dir=save_dir,
show_image=show_image,
frame_rate=frame_rate,
scaled=scaled,
det_file=os.path.join(det_results_dir,
'{}.txt'.format(seq)),
draw_threshold=draw_threshold)
......
......@@ -62,6 +62,12 @@ def parse_args():
'--show_image',
action='store_true',
help='Show tracking results (image).')
parser.add_argument(
'--scaled',
type=bool,
default=False,
help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
"True in general detector.")
args = parser.parse_args()
return args
......@@ -95,6 +101,7 @@ def run(FLAGS, cfg):
save_images=FLAGS.save_images,
save_videos=FLAGS.save_videos,
show_image=FLAGS.show_image,
scaled=FLAGS.scaled,
det_results_dir=FLAGS.det_results_dir)
......
......@@ -74,6 +74,12 @@ def parse_args():
'--show_image',
action='store_true',
help='Show tracking results (image).')
parser.add_argument(
'--scaled',
type=bool,
default=False,
help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
"True in general detector.")
parser.add_argument(
"--draw_threshold",
type=float,
......@@ -107,6 +113,7 @@ def run(FLAGS, cfg):
save_images=FLAGS.save_images,
save_videos=FLAGS.save_videos,
show_image=FLAGS.show_image,
scaled=FLAGS.scaled,
det_results_dir=FLAGS.det_results_dir,
draw_threshold=FLAGS.draw_threshold)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册