未验证 提交 3fca7404 编写于 作者: F Feng Ni 提交者: GitHub

[cherry-pick] Add pptracking api (#4861)

* fix plot_tracking_dict, test=document_fix

* add mot jde api for pptracking, test=document_fix
上级 f00a4c00
......@@ -111,6 +111,7 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=ppyolov2_r50vd_dcn_
- DeepSORT算法不支持多类别跟踪,只支持单类别跟踪,且ReID模型最好是与检测模型同一类别的物体训练过的,比如行人跟踪最好使用行人ReID模型,车辆跟踪最好使用车辆ReID模型。
## 3. 跨境跟踪模型的导出和预测
### 3.1 导出预测模型
Step 1:下载导出的检测模型
......@@ -129,11 +130,15 @@ tar -xvf deepsort_pplcnet_vehicle.tar
### 3.2 用导出的模型基于Python去做跨镜头跟踪
```bash
# 下载demo测试视频
wget https://paddledet.bj.bcebos.com/data/mot/demo/mtmct-demo.tar
tar -xvf mtmct-demo.tar
# 用导出的PicoDet车辆检测模型和PPLCNet车辆ReID模型
python deploy/pptracking/python/mot_sde_infer.py --model_dir=picodet_l_640_aic21mtmct_vehicle/ --reid_model_dir=deepsort_pplcnet_vehicle/ --mtmct_dir={your mtmct scene video folder} --mtmct_cfg=mtmct_cfg --device=GPU --scaled=True --threshold=0.5 --save_mot_txts --save_images
python deploy/pptracking/python/mot_sde_infer.py --model_dir=picodet_l_640_aic21mtmct_vehicle/ --reid_model_dir=deepsort_pplcnet_vehicle/ --mtmct_dir=mtmct-demo --mtmct_cfg=mtmct_cfg --device=GPU --scaled=True --threshold=0.5 --save_mot_txts --save_images
# 用导出的PP-YOLOv2车辆检测模型和PPLCNet车辆ReID模型
python deploy/pptracking/python/mot_sde_infer.py --model_dir=ppyolov2_r50vd_dcn_365e_aic21mtmct_vehicle/ --reid_model_dir=deepsort_pplcnet_vehicle/ --mtmct_dir={your mtmct scene video folder} --mtmct_cfg=mtmct_cfg --device=GPU --scaled=True --threshold=0.5 --save_mot_txts --save_images
python deploy/pptracking/python/mot_sde_infer.py --model_dir=ppyolov2_r50vd_dcn_365e_aic21mtmct_vehicle/ --reid_model_dir=deepsort_pplcnet_vehicle/ --mtmct_dir=mtmct-demo --mtmct_cfg=mtmct_cfg --device=GPU --scaled=True --threshold=0.5 --save_mot_txts --save_images
```
**注意:**
......@@ -146,7 +151,78 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=ppyolov2_r50vd_dcn_
- `--mtmct_cfg`是MTMCT预测的某个场景的配置文件,里面包含该一些trick操作的开关和该场景摄像头相关设置的文件路径,用户可以自行更改相关路径以及设置某些操作是否启用。
## 参数说明:
## 4. API调用方式:
### 4.1 FairMOT模型API调用
```
import mot_jde_infer
# 1.model config and weights
model_dir = 'fairmot_hrnetv2_w18_dlafpn_30e_576x320/'
# 2.inference data
video_file = 'test.mp4'
image_dir = None
# 3.other settings
device = 'CPU' # device should be CPU, GPU or XPU
threshold = 0.3
output_dir = 'output'
# mot predict
mot_jde_infer.predict_naive(model_dir, video_file, image_dir, device, threshold, output_dir)
```
**注意:**
- 以上代码必须进入目录`PaddleDetection/deploy/pptracking/python`下执行。
- 支持对视频和图片文件夹进行预测,不支持单张图的预测,`video_file``image_dir`不能同时为None,推荐使用`video_file`,而`image_dir`需直接存放命名顺序规范的图片。
- 默认会保存跟踪结果可视化后的图片和视频,以及跟踪结果txt文件,默认不会进行轨迹可视化和流量统计。
### 4.2 DeepSORT模型API调用
```
import mot_sde_infer
# 1.model config and weights
model_dir = 'ppyolov2_r50vd_dcn_365e_aic21mtmct_vehicle/'
reid_model_dir = 'deepsort_pplcnet_vehicle/'
# 2.inference data
video_file = 'test.mp4'
image_dir = None
# 3.other settings
scaled = True # set False only when use JDE YOLOv3
device = 'CPU' # device should be CPU, GPU or XPU
threshold = 0.3
output_dir = 'output'
# 4. MTMCT settings, default None
mtmct_dir = None
mtmct_cfg = None
# mot predict
mot_sde_infer.predict_naive(model_dir,
reid_model_dir,
video_file,
image_dir,
mtmct_dir,
mtmct_cfg,
scaled,
device,
threshold,
output_dir)
```
**注意:**
- 以上代码必须进入目录`PaddleDetection/deploy/pptracking/python`下执行。
- 支持对视频和图片文件夹进行预测,不支持单张图的预测,`video_file``image_dir``--mtmct_dir`不能同时为None,推荐使用`video_file`,而`image_dir`需直接存放命名顺序规范的图片,`--mtmct_dir`不为None表示是进行的MTMCT跨镜头跟踪任务。
- 默认会保存跟踪结果可视化后的图片和视频,以及跟踪结果txt文件,默认不会进行轨迹可视化和流量统计。
- `--scaled`表示在模型输出结果的坐标是否已经是缩放回原图的,如果使用的检测模型是JDE的YOLOv3则为False,如果使用通用检测模型则为True。
- `--mtmct_dir`是MTMCT预测的某个场景的文件夹名字,里面包含该场景不同摄像头拍摄视频的图片文件夹,其数量至少为两个。
- `--mtmct_cfg`是MTMCT预测的某个场景的配置文件,里面包含该一些trick操作的开关和该场景摄像头相关设置的文件路径,用户可以自行更改相关路径以及设置某些操作是否启用。
- 开启MTMCT预测必须将`video_file``image_dir`同时设置为None,且`--mtmct_dir``--mtmct_cfg`都必须不为None。
## 5. 参数说明:
| 参数 | 是否必须|含义 |
|-------|-------|----------|
......
......@@ -167,7 +167,12 @@ class JDE_Detector(Detector):
return online_tlwhs, online_scores, online_ids
def predict_image(detector, image_list):
def predict_image(detector,
image_list,
threshold,
output_dir,
save_images=True,
run_benchmark=False):
results = []
num_classes = detector.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot'
......@@ -176,13 +181,11 @@ def predict_image(detector, image_list):
image_list.sort()
for frame_id, img_file in enumerate(image_list):
frame = cv2.imread(img_file)
if FLAGS.run_benchmark:
if run_benchmark:
# warmup
detector.predict(
[img_file], FLAGS.threshold, repeats=10, add_timer=False)
detector.predict([img_file], threshold, repeats=10, add_timer=False)
# run benchmark
detector.predict(
[img_file], FLAGS.threshold, repeats=10, add_timer=True)
detector.predict([img_file], threshold, repeats=10, add_timer=True)
cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm
detector.gpu_mem += gm
......@@ -190,7 +193,7 @@ def predict_image(detector, image_list):
print('Test iter {}, file name:{}'.format(frame_id, img_file))
else:
online_tlwhs, online_scores, online_ids = detector.predict(
[img_file], FLAGS.threshold)
[img_file], threshold)
online_im = plot_tracking_dict(
frame,
num_classes,
......@@ -199,22 +202,32 @@ def predict_image(detector, image_list):
online_scores,
frame_id=frame_id,
ids2names=ids2names)
if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
if save_images:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
img_name = os.path.split(img_file)[-1]
out_path = os.path.join(FLAGS.output_dir, img_name)
out_path = os.path.join(output_dir, img_name)
cv2.imwrite(out_path, online_im)
print("save result to: " + out_path)
def predict_video(detector, camera_id):
def predict_video(detector,
video_file,
threshold,
output_dir,
save_images=True,
save_mot_txts=True,
draw_center_traj=False,
secs_interval=10,
do_entrance_counting=False,
camera_id=-1):
video_name = 'mot_output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
capture = cv2.VideoCapture(video_file)
video_name = os.path.split(video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
......@@ -222,10 +235,10 @@ def predict_video(detector, camera_id):
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
out_path = os.path.join(output_dir, video_name)
if not save_images:
video_format = 'mp4v'
fourcc = cv2.VideoWriter_fourcc(*video_format)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
......@@ -238,7 +251,7 @@ def predict_video(detector, camera_id):
center_traj = None
entrance = None
records = None
if FLAGS.draw_center_traj:
if draw_center_traj:
center_traj = [{} for i in range(num_classes)]
if num_classes == 1:
......@@ -257,8 +270,8 @@ def predict_video(detector, camera_id):
if not ret:
break
timer.tic()
online_tlwhs, online_scores, online_ids = detector.predict(
[frame], FLAGS.threshold)
online_tlwhs, online_scores, online_ids = detector.predict([frame],
threshold)
timer.toc()
for cls_id in range(num_classes):
......@@ -271,9 +284,9 @@ def predict_video(detector, camera_id):
result = (frame_id + 1, online_tlwhs[0], online_scores[0],
online_ids[0])
statistic = flow_statistic(
result, FLAGS.secs_interval, FLAGS.do_entrance_counting,
video_fps, entrance, id_set, interval_id_set, in_id_list,
out_id_list, prev_center, records, data_type, num_classes)
result, secs_interval, do_entrance_counting, video_fps,
entrance, id_set, interval_id_set, in_id_list, out_id_list,
prev_center, records, data_type, num_classes)
id_set = statistic['id_set']
interval_id_set = statistic['interval_id_set']
in_id_list = statistic['in_id_list']
......@@ -281,7 +294,7 @@ def predict_video(detector, camera_id):
prev_center = statistic['prev_center']
records = statistic['records']
elif num_classes > 1 and FLAGS.do_entrance_counting:
elif num_classes > 1 and do_entrance_counting:
raise NotImplementedError(
'Multi-class flow counting is not implemented now!')
im = plot_tracking_dict(
......@@ -293,13 +306,13 @@ def predict_video(detector, camera_id):
frame_id=frame_id,
fps=fps,
ids2names=ids2names,
do_entrance_counting=FLAGS.do_entrance_counting,
do_entrance_counting=do_entrance_counting,
entrance=entrance,
records=records,
center_traj=center_traj)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if save_images:
save_dir = os.path.join(output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
......@@ -313,24 +326,23 @@ def predict_video(detector, camera_id):
cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir,
if save_mot_txts:
result_filename = os.path.join(output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results, data_type, num_classes)
if num_classes == 1:
result_filename = os.path.join(
FLAGS.output_dir,
video_name.split('.')[-2] + '_flow_statistic.txt')
output_dir, video_name.split('.')[-2] + '_flow_statistic.txt')
f = open(result_filename, 'w')
for line in records:
f.write(line)
print('Flow statistic save in {}'.format(result_filename))
f.close()
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if save_images:
save_dir = os.path.join(output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(save_dir,
out_path)
os.system(cmd_str)
......@@ -339,6 +351,36 @@ def predict_video(detector, camera_id):
writer.release()
def predict_naive(model_dir,
video_file,
image_dir,
device='gpu',
threshold=0.5,
output_dir='output'):
pred_config = PredictConfig(model_dir)
detector = JDE_Detector(pred_config, model_dir, device=device.upper())
if video_file is not None:
predict_video(
detector,
video_file,
threshold=threshold,
output_dir=output_dir,
save_images=True,
save_mot_txts=True,
draw_center_traj=False,
secs_interval=10,
do_entrance_counting=False)
else:
img_list = get_test_images(image_dir, infer_img=None)
predict_image(
detector,
img_list,
threshold=threshold,
output_dir=output_dir,
save_images=True)
def main():
pred_config = PredictConfig(FLAGS.model_dir)
detector = JDE_Detector(
......@@ -355,11 +397,27 @@ def main():
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, FLAGS.camera_id)
predict_video(
detector,
FLAGS.video_file,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir,
save_images=FLAGS.save_images,
save_mot_txts=FLAGS.save_mot_txts,
draw_center_traj=FLAGS.draw_center_traj,
secs_interval=FLAGS.secs_interval,
do_entrance_counting=FLAGS.do_entrance_counting,
camera_id=FLAGS.camera_id)
else:
# predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, img_list)
predict_image(
detector,
img_list,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir,
save_images=FLAGS.save_images,
run_benchmark=FLAGS.run_benchmark)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
else:
......
......@@ -316,6 +316,8 @@ class SDE_DetectorPicoDet(DetectorPicoDet):
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
np_score_list, np_boxes_list = [], []
# model prediction
for i in range(repeats):
self.predictor.run()
......@@ -549,26 +551,33 @@ class SDE_ReID(object):
return tracking_outs
def predict_image(detector, reid_model, image_list):
def predict_image(detector,
reid_model,
image_list,
threshold,
output_dir,
scaled=True,
save_images=True,
run_benchmark=False):
image_list.sort()
for i, img_file in enumerate(image_list):
frame = cv2.imread(img_file)
ori_image_shape = list(frame.shape[:2])
if FLAGS.run_benchmark:
if run_benchmark:
# warmup
pred_dets, pred_xyxys = detector.predict(
[img_file],
ori_image_shape,
FLAGS.threshold,
FLAGS.scaled,
threshold,
scaled,
repeats=10,
add_timer=False)
# run benchmark
pred_dets, pred_xyxys = detector.predict(
[img_file],
ori_image_shape,
FLAGS.threshold,
FLAGS.scaled,
threshold,
scaled,
repeats=10,
add_timer=True)
......@@ -579,7 +588,7 @@ def predict_image(detector, reid_model, image_list):
print('Test iter {}, file name:{}'.format(i, img_file))
else:
pred_dets, pred_xyxys = detector.predict(
[img_file], ori_image_shape, FLAGS.threshold, FLAGS.scaled)
[img_file], ori_image_shape, threshold, scaled)
if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
print('Frame {} has no object, try to modify score threshold.'.
......@@ -589,7 +598,7 @@ def predict_image(detector, reid_model, image_list):
# reid process
crops = reid_model.get_crops(pred_xyxys, frame)
if FLAGS.run_benchmark:
if run_benchmark:
# warmup
tracking_outs = reid_model.predict(
crops, pred_dets, repeats=10, add_timer=False)
......@@ -607,22 +616,34 @@ def predict_image(detector, reid_model, image_list):
online_im = plot_tracking(
frame, online_tlwhs, online_ids, online_scores, frame_id=i)
if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
if save_images:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
img_name = os.path.split(img_file)[-1]
out_path = os.path.join(FLAGS.output_dir, img_name)
out_path = os.path.join(output_dir, img_name)
cv2.imwrite(out_path, online_im)
print("save result to: " + out_path)
def predict_video(detector, reid_model, camera_id):
def predict_video(detector,
reid_model,
video_file,
scaled,
threshold,
output_dir,
save_images=True,
save_mot_txts=True,
draw_center_traj=False,
secs_interval=10,
do_entrance_counting=False,
camera_id=-1):
video_name = 'mot_output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
video_name = 'mot_output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
capture = cv2.VideoCapture(video_file)
video_name = os.path.split(video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
......@@ -630,10 +651,10 @@ def predict_video(detector, reid_model, camera_id):
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
out_path = os.path.join(output_dir, video_name)
if not save_images:
video_format = 'mp4v'
fourcc = cv2.VideoWriter_fourcc(*video_format)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
......@@ -656,7 +677,7 @@ def predict_video(detector, reid_model, camera_id):
timer.tic()
ori_image_shape = list(frame.shape[:2])
pred_dets, pred_xyxys = detector.predict([frame], ori_image_shape,
FLAGS.threshold, FLAGS.scaled)
threshold, scaled)
if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
print('Frame {} has no object, try to modify score threshold.'.
......@@ -677,9 +698,9 @@ def predict_video(detector, reid_model, camera_id):
# NOTE: just implement flow statistic for one class
result = (frame_id + 1, online_tlwhs, online_scores, online_ids)
statistic = flow_statistic(
result, FLAGS.secs_interval, FLAGS.do_entrance_counting,
video_fps, entrance, id_set, interval_id_set, in_id_list,
out_id_list, prev_center, records)
result, secs_interval, do_entrance_counting, video_fps,
entrance, id_set, interval_id_set, in_id_list, out_id_list,
prev_center, records)
id_set = statistic['id_set']
interval_id_set = statistic['interval_id_set']
in_id_list = statistic['in_id_list']
......@@ -697,11 +718,11 @@ def predict_video(detector, reid_model, camera_id):
online_scores,
frame_id=frame_id,
fps=fps,
do_entrance_counting=FLAGS.do_entrance_counting,
do_entrance_counting=do_entrance_counting,
entrance=entrance)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if save_images:
save_dir = os.path.join(output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
......@@ -717,21 +738,21 @@ def predict_video(detector, reid_model, camera_id):
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir,
if save_mot_txts:
result_filename = os.path.join(output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results)
result_filename = os.path.join(
FLAGS.output_dir, video_name.split('.')[-2] + '_flow_statistic.txt')
output_dir, video_name.split('.')[-2] + '_flow_statistic.txt')
f = open(result_filename, 'w')
for line in records:
f.write(line)
print('Flow statistic save in {}'.format(result_filename))
f.close()
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if save_images:
save_dir = os.path.join(output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(save_dir,
out_path)
os.system(cmd_str)
......@@ -740,8 +761,16 @@ def predict_video(detector, reid_model, camera_id):
writer.release()
def predict_mtmct_seq(detector, reid_model, seq_name, output_dir):
fpath = os.path.join(FLAGS.mtmct_dir, seq_name)
def predict_mtmct_seq(detector,
reid_model,
mtmct_dir,
seq_name,
scaled,
threshold,
output_dir,
save_images=True,
save_mot_txts=True):
fpath = os.path.join(mtmct_dir, seq_name)
if os.path.exists(os.path.join(fpath, 'img1')):
fpath = os.path.join(fpath, 'img1')
......@@ -756,13 +785,13 @@ def predict_mtmct_seq(detector, reid_model, seq_name, output_dir):
len(image_list), seq_name))
for frame_id, img_file in enumerate(image_list):
if frame_id % 40 == 0:
if frame_id % 10 == 0:
print('Processing frame {} of seq {}.'.format(frame_id, seq_name))
frame = cv2.imread(os.path.join(fpath, img_file))
ori_image_shape = list(frame.shape[:2])
frame_path = os.path.join(fpath, img_file)
pred_dets, pred_xyxys = detector.predict([frame_path], ori_image_shape,
FLAGS.threshold, FLAGS.scaled)
threshold, scaled)
if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
print('Frame {} has no object, try to modify score threshold.'.
......@@ -791,21 +820,29 @@ def predict_mtmct_seq(detector, reid_model, seq_name, output_dir):
results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
if FLAGS.save_images:
if save_images:
save_dir = os.path.join(output_dir, seq_name)
if not os.path.exists(save_dir): os.makedirs(save_dir)
img_name = os.path.split(img_file)[-1]
out_path = os.path.join(save_dir, img_name)
cv2.imwrite(out_path, online_im)
if FLAGS.save_mot_txts:
if save_mot_txts:
result_filename = os.path.join(output_dir, seq_name + '.txt')
write_mot_results(result_filename, results)
return mot_features_dict
def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
def predict_mtmct(detector,
reid_model,
mtmct_dir,
mtmct_cfg,
scaled,
threshold,
output_dir,
save_images=True,
save_mot_txts=True):
MTMCT = mtmct_cfg['MTMCT']
assert MTMCT == True, 'predict_mtmct should be used for MTMCT.'
......@@ -832,7 +869,6 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
mot_list_breaks = []
cid_tid_dict = dict()
output_dir = FLAGS.output_dir
if not os.path.exists(output_dir): os.makedirs(output_dir)
seqs = os.listdir(mtmct_dir)
......@@ -852,8 +888,9 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
print('{} is not a image folder.'.format(fpath))
continue
mot_features_dict = predict_mtmct_seq(detector, reid_model, seq,
output_dir)
mot_features_dict = predict_mtmct_seq(
detector, reid_model, mtmct_dir, seq, scaled, threshold, output_dir,
save_images, save_mot_txts)
cid = int(re.sub('[a-z,A-Z]', "", seq))
tid_data, mot_list_break = trajectory_fusion(
......@@ -911,6 +948,62 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
print_mtmct_result(data_root_gt, pred_mtmct_file)
def predict_naive(model_dir,
reid_model_dir,
video_file,
image_dir,
mtmct_dir=None,
mtmct_cfg=None,
scaled=True,
device='gpu',
threshold=0.5,
output_dir='output'):
pred_config = PredictConfig(model_dir)
detector_func = 'SDE_Detector'
if pred_config.arch == 'PicoDet':
detector_func = 'SDE_DetectorPicoDet'
detector = eval(detector_func)(pred_config, model_dir, device=device)
pred_config = PredictConfig(reid_model_dir)
reid_model = SDE_ReID(pred_config, reid_model_dir, device=device)
if video_file is not None:
predict_video(
detector,
reid_model,
video_file,
scaled=scaled,
threshold=threshold,
output_dir=output_dir,
save_images=True,
save_mot_txts=True,
draw_center_traj=False,
secs_interval=10,
do_entrance_counting=False)
elif mtmct_dir is not None:
with open(mtmct_cfg) as f:
mtmct_cfg_file = yaml.safe_load(f)
predict_mtmct(
detector,
reid_model,
mtmct_dir,
mtmct_cfg_file,
scaled=scaled,
threshold=threshold,
output_dir=output_dir,
save_images=True,
save_mot_txts=True)
else:
img_list = get_test_images(image_dir, infer_img=None)
predict_image(
detector,
reid_model,
img_list,
threshold=threshold,
output_dir=output_dir,
save_images=True)
def main():
pred_config = PredictConfig(FLAGS.model_dir)
detector_func = 'SDE_Detector'
......@@ -945,18 +1038,45 @@ def main():
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, reid_model, FLAGS.camera_id)
predict_video(
detector,
reid_model,
FLAGS.video_file,
scaled=FLAGS.scaled,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir,
save_images=FLAGS.save_images,
save_mot_txts=FLAGS.save_mot_txts,
draw_center_traj=FLAGS.draw_center_traj,
secs_interval=FLAGS.secs_interval,
do_entrance_counting=FLAGS.do_entrance_counting,
camera_id=FLAGS.camera_id)
elif FLAGS.mtmct_dir is not None:
mtmct_cfg_file = FLAGS.mtmct_cfg
with open(mtmct_cfg_file) as f:
mtmct_cfg = yaml.safe_load(f)
predict_mtmct(detector, reid_model, FLAGS.mtmct_dir, mtmct_cfg)
predict_mtmct(
detector,
reid_model,
FLAGS.mtmct_dir,
mtmct_cfg,
scaled=FLAGS.scaled,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir,
save_images=FLAGS.save_images,
save_mot_txts=FLAGS.save_mot_txts)
else:
# predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, reid_model, img_list)
predict_image(
detector,
reid_model,
img_list,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir,
save_images=FLAGS.save_images,
run_benchmark=FLAGS.run_benchmark)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册