From 3fca7404d0a253545c4f8e8e2c4d69ff1016a42e Mon Sep 17 00:00:00 2001
From: Feng Ni
Date: Thu, 9 Dec 2021 18:44:04 +0800
Subject: [PATCH] [cherry-pick] Add pptracking api (#4861)

* fix plot_tracking_dict, test=document_fix

* add mot jde api for pptracking, test=document_fix
---
 deploy/pptracking/python/README.md        |  82 ++++++++-
 deploy/pptracking/python/mot_jde_infer.py | 130 +++++++++----
 deploy/pptracking/python/mot_sde_infer.py | 212 +++++++++++++++++-----
 3 files changed, 339 insertions(+), 85 deletions(-)

diff --git a/deploy/pptracking/python/README.md b/deploy/pptracking/python/README.md
index 25e09ce6b..0dcbf61d9 100644
--- a/deploy/pptracking/python/README.md
+++ b/deploy/pptracking/python/README.md
@@ -111,6 +111,7 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=ppyolov2_r50vd_dcn_
 - DeepSORT does not support multi-class tracking, only single-class tracking, and the ReID model should preferably be trained on the same object category as the detection model: use a pedestrian ReID model for pedestrian tracking and a vehicle ReID model for vehicle tracking.
 
+
 ## 3. Export and inference of the cross-camera tracking model
 ### 3.1 Export the inference models
 Step 1: download the exported detection model
@@ -129,11 +130,15 @@ tar -xvf deepsort_pplcnet_vehicle.tar
 ### 3.2 Cross-camera tracking in Python with the exported models
 
 ```bash
+# Download the demo test video
+wget https://paddledet.bj.bcebos.com/data/mot/demo/mtmct-demo.tar
+tar -xvf mtmct-demo.tar
+
 # Use the exported PicoDet vehicle detection model and the PPLCNet vehicle ReID model
-python deploy/pptracking/python/mot_sde_infer.py --model_dir=picodet_l_640_aic21mtmct_vehicle/ --reid_model_dir=deepsort_pplcnet_vehicle/ --mtmct_dir={your mtmct scene video folder} --mtmct_cfg=mtmct_cfg --device=GPU --scaled=True --threshold=0.5 --save_mot_txts --save_images
+python deploy/pptracking/python/mot_sde_infer.py --model_dir=picodet_l_640_aic21mtmct_vehicle/ --reid_model_dir=deepsort_pplcnet_vehicle/ --mtmct_dir=mtmct-demo --mtmct_cfg=mtmct_cfg --device=GPU --scaled=True --threshold=0.5 --save_mot_txts --save_images
 
 # Use the exported PP-YOLOv2 vehicle detection model and the PPLCNet vehicle ReID model
-python deploy/pptracking/python/mot_sde_infer.py --model_dir=ppyolov2_r50vd_dcn_365e_aic21mtmct_vehicle/ --reid_model_dir=deepsort_pplcnet_vehicle/ --mtmct_dir={your mtmct scene video folder} --mtmct_cfg=mtmct_cfg --device=GPU --scaled=True --threshold=0.5 --save_mot_txts --save_images
+python deploy/pptracking/python/mot_sde_infer.py --model_dir=ppyolov2_r50vd_dcn_365e_aic21mtmct_vehicle/ --reid_model_dir=deepsort_pplcnet_vehicle/ --mtmct_dir=mtmct-demo --mtmct_cfg=mtmct_cfg --device=GPU --scaled=True --threshold=0.5 --save_mot_txts --save_images
 ```
 
 **Note:**
@@ -146,7 +151,78 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=ppyolov2_r50vd_dcn_
 - `--mtmct_cfg` is the config file of one MTMCT scene. It contains the switches of several trick operations and the file paths of the camera-related settings for that scene; you can change these paths and toggle the operations as needed, as shown in the sketch below.
 
 
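+For reference, the scene config passed via `--mtmct_cfg` is a plain YAML file: `predict_mtmct` loads it with `yaml.safe_load` and asserts that it contains `MTMCT: True` before running. A minimal sketch of loading and checking such a config (the path `mtmct_cfg` is a placeholder for your own file):
+
+```python
+import yaml
+
+# Path of the scene config, i.e. the value passed to --mtmct_cfg.
+cfg_path = 'mtmct_cfg'
+
+with open(cfg_path) as f:
+    mtmct_cfg = yaml.safe_load(f)
+
+# predict_mtmct() requires this switch to be enabled.
+assert mtmct_cfg['MTMCT'] == True
+```
+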
-## Parameter description:
+## 4. Calling via the Python API:
+
+### 4.1 FairMOT model API call
+```python
+import mot_jde_infer
+
+# 1. model config and weights
+model_dir = 'fairmot_hrnetv2_w18_dlafpn_30e_576x320/'
+
+# 2. inference data
+video_file = 'test.mp4'
+image_dir = None
+
+# 3. other settings
+device = 'CPU'  # device should be CPU, GPU or XPU
+threshold = 0.3
+output_dir = 'output'
+
+# mot predict
+mot_jde_infer.predict_naive(model_dir, video_file, image_dir, device,
+                            threshold, output_dir)
+```
+**Note:**
+ - The code above must be run from the `PaddleDetection/deploy/pptracking/python` directory.
+ - Prediction on a video file or on a folder of images is supported; single-image prediction is not. `video_file` and `image_dir` must not both be None. `video_file` is recommended; `image_dir` must directly contain images whose names follow a consistent ordering.
+ - By default, the visualized tracking results are saved as images and a video together with a txt file of the tracking results; trajectory visualization and flow counting are disabled by default.
+
+
+### 4.2 DeepSORT model API call
+```python
+import mot_sde_infer
+
+# 1. model config and weights
+model_dir = 'ppyolov2_r50vd_dcn_365e_aic21mtmct_vehicle/'
+reid_model_dir = 'deepsort_pplcnet_vehicle/'
+
+# 2. inference data
+video_file = 'test.mp4'
+image_dir = None
+
+# 3. other settings
+scaled = True  # set False only when using JDE YOLOv3
+device = 'CPU'  # device should be CPU, GPU or XPU
+threshold = 0.3
+output_dir = 'output'
+
+# 4. MTMCT settings, default None
+mtmct_dir = None
+mtmct_cfg = None
+
+# mot predict
+mot_sde_infer.predict_naive(model_dir,
+                            reid_model_dir,
+                            video_file,
+                            image_dir,
+                            mtmct_dir,
+                            mtmct_cfg,
+                            scaled,
+                            device,
+                            threshold,
+                            output_dir)
+```
+**Note:**
+ - The code above must be run from the `PaddleDetection/deploy/pptracking/python` directory.
+ - Prediction on a video file or on a folder of images is supported; single-image prediction is not. `video_file`, `image_dir` and `mtmct_dir` must not all be None. `video_file` is recommended; `image_dir` must directly contain images whose names follow a consistent ordering. A non-None `mtmct_dir` means an MTMCT cross-camera tracking task will be run.
+ - By default, the visualized tracking results are saved as images and a video together with a txt file of the tracking results; trajectory visualization and flow counting are disabled by default.
+ - `scaled` indicates whether the coordinates in the model output are already scaled back to the original image: set it to False if the detection model is JDE YOLOv3, and to True for a general detection model.
+ - `mtmct_dir` is the folder of one MTMCT scene; it contains the image folders extracted from the videos of the different cameras in that scene, and there must be at least two of them.
+ - `mtmct_cfg` is the config file of that MTMCT scene. It contains the switches of several trick operations and the file paths of the camera-related settings for the scene; you can change these paths and toggle the operations as needed.
+ - To enable MTMCT prediction, `video_file` and `image_dir` must both be set to None, and `mtmct_dir` and `mtmct_cfg` must both be non-None, as in the sketch below.
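+
+To make the last note concrete, the following is a minimal sketch of an MTMCT call through the same API, reusing the exported vehicle models and the `mtmct-demo` data and `mtmct_cfg` scene config from Section 3.2 (substitute your own paths as needed):
+
+```python
+import mot_sde_infer
+
+# MTMCT mode: video_file and image_dir are both None,
+# while mtmct_dir and mtmct_cfg are both set.
+mot_sde_infer.predict_naive('picodet_l_640_aic21mtmct_vehicle/',
+                            'deepsort_pplcnet_vehicle/',
+                            video_file=None,
+                            image_dir=None,
+                            mtmct_dir='mtmct-demo',
+                            mtmct_cfg='mtmct_cfg',
+                            scaled=True,
+                            device='GPU',
+                            threshold=0.5,
+                            output_dir='output')
+```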
+
+
+## 5. Parameter description:
 
 | Parameter | Required | Description |
 |-------|-------|----------|
diff --git a/deploy/pptracking/python/mot_jde_infer.py b/deploy/pptracking/python/mot_jde_infer.py
index bcad3a241..73ab25f78 100644
--- a/deploy/pptracking/python/mot_jde_infer.py
+++ b/deploy/pptracking/python/mot_jde_infer.py
@@ -167,7 +167,12 @@ class JDE_Detector(Detector):
         return online_tlwhs, online_scores, online_ids
 
 
-def predict_image(detector, image_list):
+def predict_image(detector,
+                  image_list,
+                  threshold,
+                  output_dir,
+                  save_images=True,
+                  run_benchmark=False):
     results = []
     num_classes = detector.num_classes
     data_type = 'mcmot' if num_classes > 1 else 'mot'
@@ -176,13 +181,11 @@
     image_list.sort()
     for frame_id, img_file in enumerate(image_list):
         frame = cv2.imread(img_file)
-        if FLAGS.run_benchmark:
+        if run_benchmark:
             # warmup
-            detector.predict(
-                [img_file], FLAGS.threshold, repeats=10, add_timer=False)
+            detector.predict([img_file], threshold, repeats=10, add_timer=False)
             # run benchmark
-            detector.predict(
-                [img_file], FLAGS.threshold, repeats=10, add_timer=True)
+            detector.predict([img_file], threshold, repeats=10, add_timer=True)
             cm, gm, gu = get_current_memory_mb()
             detector.cpu_mem += cm
             detector.gpu_mem += gm
@@ -190,7 +193,7 @@
             print('Test iter {}, file name:{}'.format(frame_id, img_file))
         else:
             online_tlwhs, online_scores, online_ids = detector.predict(
-                [img_file], FLAGS.threshold)
+                [img_file], threshold)
             online_im = plot_tracking_dict(
                 frame,
                 num_classes,
@@ -199,22 +202,32 @@
                 online_scores,
                 frame_id=frame_id,
                 ids2names=ids2names)
-            if FLAGS.save_images:
-                if not os.path.exists(FLAGS.output_dir):
-                    os.makedirs(FLAGS.output_dir)
+            if save_images:
+                if not os.path.exists(output_dir):
+                    os.makedirs(output_dir)
                 img_name = os.path.split(img_file)[-1]
-                out_path = os.path.join(FLAGS.output_dir, img_name)
+                out_path = os.path.join(output_dir, img_name)
                 cv2.imwrite(out_path, online_im)
                 print("save result to: " + out_path)
 
 
-def predict_video(detector, camera_id):
+def predict_video(detector,
+                  video_file,
+                  threshold,
+                  output_dir,
+                  save_images=True,
+                  save_mot_txts=True,
+                  draw_center_traj=False,
+                  secs_interval=10,
+                  do_entrance_counting=False,
+                  camera_id=-1):
     video_name = 'mot_output.mp4'
     if camera_id != -1:
         capture = cv2.VideoCapture(camera_id)
     else:
-        capture = cv2.VideoCapture(FLAGS.video_file)
-        video_name = os.path.split(FLAGS.video_file)[-1]
+        capture = cv2.VideoCapture(video_file)
+        video_name = os.path.split(video_file)[-1]
+
     # Get Video info : resolution, fps, frame count
     width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
     height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
@@ -222,10 +235,10 @@
     frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
     print("fps: %d, frame_count: %d" % (fps, frame_count))
 
-    if not os.path.exists(FLAGS.output_dir):
-        os.makedirs(FLAGS.output_dir)
-    out_path = os.path.join(FLAGS.output_dir, video_name)
-    if not FLAGS.save_images:
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    out_path = os.path.join(output_dir, video_name)
+    if not save_images:
         video_format = 'mp4v'
         fourcc = cv2.VideoWriter_fourcc(*video_format)
         writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
@@ -238,7 +251,7 @@
     center_traj = None
     entrance = None
     records = None
-    if FLAGS.draw_center_traj:
+    if draw_center_traj:
         center_traj = [{} for i in range(num_classes)]
 
     if num_classes == 1:
@@ -257,8 +270,8 @@
         if not ret:
             break
         timer.tic()
-        online_tlwhs, online_scores, online_ids = detector.predict(
-            [frame], FLAGS.threshold)
+        online_tlwhs, online_scores, online_ids = detector.predict([frame],
+                                                                   threshold)
         timer.toc()
 
         for cls_id in range(num_classes):
@@ -271,9 +284,9 @@
             result = (frame_id + 1, online_tlwhs[0], online_scores[0],
                       online_ids[0])
             statistic = flow_statistic(
-                result, FLAGS.secs_interval, FLAGS.do_entrance_counting,
-                video_fps, entrance, id_set, interval_id_set, in_id_list,
-                out_id_list, prev_center, records, data_type, num_classes)
+                result, secs_interval, do_entrance_counting, video_fps,
+                entrance, id_set, interval_id_set, in_id_list, out_id_list,
+                prev_center, records, data_type, num_classes)
             id_set = statistic['id_set']
             interval_id_set = statistic['interval_id_set']
             in_id_list = statistic['in_id_list']
@@ -281,7 +294,7 @@
             prev_center = statistic['prev_center']
             records = statistic['records']
 
-        elif num_classes > 1 and FLAGS.do_entrance_counting:
+        elif num_classes > 1 and do_entrance_counting:
             raise NotImplementedError(
                 'Multi-class flow counting is not implemented now!')
         im = plot_tracking_dict(
@@ -293,13 +306,13 @@
             frame_id=frame_id,
             fps=fps,
             ids2names=ids2names,
-            do_entrance_counting=FLAGS.do_entrance_counting,
+            do_entrance_counting=do_entrance_counting,
             entrance=entrance,
             records=records,
             center_traj=center_traj)
 
-        if FLAGS.save_images:
-            save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
+        if save_images:
+            save_dir = os.path.join(output_dir, video_name.split('.')[-2])
             if not os.path.exists(save_dir):
                 os.makedirs(save_dir)
             cv2.imwrite(
@@ -313,24 +326,23 @@
             cv2.imshow('Tracking Detection', im)
             if cv2.waitKey(1) & 0xFF == ord('q'):
                 break
-    if FLAGS.save_mot_txts:
-        result_filename = os.path.join(FLAGS.output_dir,
+    if save_mot_txts:
+        result_filename = os.path.join(output_dir,
                                        video_name.split('.')[-2] + '.txt')
         write_mot_results(result_filename, results, data_type, num_classes)
 
         if num_classes == 1:
             result_filename = os.path.join(
-                FLAGS.output_dir,
-                video_name.split('.')[-2] + '_flow_statistic.txt')
+                output_dir, video_name.split('.')[-2] + '_flow_statistic.txt')
             f = open(result_filename, 'w')
             for line in records:
                 f.write(line)
             print('Flow statistic save in {}'.format(result_filename))
             f.close()
 
-    if FLAGS.save_images:
-        save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
+    if save_images:
+        save_dir = os.path.join(output_dir, video_name.split('.')[-2])
         cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(save_dir,
                                                               out_path)
         os.system(cmd_str)
@@ -339,6 +351,36 @@
         writer.release()
 
 
+def predict_naive(model_dir,
+                  video_file,
+                  image_dir,
+                  device='gpu',
+                  threshold=0.5,
+                  output_dir='output'):
+    pred_config = PredictConfig(model_dir)
+    detector = JDE_Detector(pred_config, model_dir, device=device.upper())
+
+    if video_file is not None:
+        predict_video(
+            detector,
+            video_file,
+            threshold=threshold,
+            output_dir=output_dir,
+            save_images=True,
+            save_mot_txts=True,
+            draw_center_traj=False,
+            secs_interval=10,
+            do_entrance_counting=False)
+    else:
+        img_list = get_test_images(image_dir, infer_img=None)
+        predict_image(
+            detector,
+            img_list,
+            threshold=threshold,
+            output_dir=output_dir,
+            save_images=True)
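+
+
+# Example usage of predict_naive (a sketch; the model path and video file
+# are placeholders, see Section 4.1 of the README):
+#     predict_naive('fairmot_hrnetv2_w18_dlafpn_30e_576x320/', 'test.mp4',
+#                   None, device='GPU', threshold=0.3, output_dir='output')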
+
+
 def main():
     pred_config = PredictConfig(FLAGS.model_dir)
     detector = JDE_Detector(
@@ -355,11 +397,27 @@
 
     # predict from video file or camera video stream
     if FLAGS.video_file is not None or FLAGS.camera_id != -1:
-        predict_video(detector, FLAGS.camera_id)
+        predict_video(
+            detector,
+            FLAGS.video_file,
+            threshold=FLAGS.threshold,
+            output_dir=FLAGS.output_dir,
+            save_images=FLAGS.save_images,
+            save_mot_txts=FLAGS.save_mot_txts,
+            draw_center_traj=FLAGS.draw_center_traj,
+            secs_interval=FLAGS.secs_interval,
+            do_entrance_counting=FLAGS.do_entrance_counting,
+            camera_id=FLAGS.camera_id)
     else:
         # predict from image
         img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
-        predict_image(detector, img_list)
+        predict_image(
+            detector,
+            img_list,
+            threshold=FLAGS.threshold,
+            output_dir=FLAGS.output_dir,
+            save_images=FLAGS.save_images,
+            run_benchmark=FLAGS.run_benchmark)
 
     if not FLAGS.run_benchmark:
         detector.det_times.info(average=True)
     else:
diff --git a/deploy/pptracking/python/mot_sde_infer.py b/deploy/pptracking/python/mot_sde_infer.py
index 1a3110a46..bbf05f27a 100644
--- a/deploy/pptracking/python/mot_sde_infer.py
+++ b/deploy/pptracking/python/mot_sde_infer.py
@@ -316,6 +316,8 @@ class SDE_DetectorPicoDet(DetectorPicoDet):
         self.det_times.preprocess_time_s.end()
         self.det_times.inference_time_s.start()
 
+        np_score_list, np_boxes_list = [], []
+
         # model prediction
         for i in range(repeats):
             self.predictor.run()
@@ -549,26 +551,33 @@
         return tracking_outs
 
 
-def predict_image(detector, reid_model, image_list):
+def predict_image(detector,
+                  reid_model,
+                  image_list,
+                  threshold,
+                  output_dir,
+                  scaled=True,
+                  save_images=True,
+                  run_benchmark=False):
     image_list.sort()
     for i, img_file in enumerate(image_list):
         frame = cv2.imread(img_file)
         ori_image_shape = list(frame.shape[:2])
-        if FLAGS.run_benchmark:
+        if run_benchmark:
             # warmup
             pred_dets, pred_xyxys = detector.predict(
                 [img_file],
                 ori_image_shape,
-                FLAGS.threshold,
-                FLAGS.scaled,
+                threshold,
+                scaled,
                 repeats=10,
                 add_timer=False)
             # run benchmark
             pred_dets, pred_xyxys = detector.predict(
                 [img_file],
                 ori_image_shape,
-                FLAGS.threshold,
-                FLAGS.scaled,
+                threshold,
+                scaled,
                 repeats=10,
                 add_timer=True)
@@ -579,7 +588,7 @@
             print('Test iter {}, file name:{}'.format(i, img_file))
         else:
             pred_dets, pred_xyxys = detector.predict(
-                [img_file], ori_image_shape, FLAGS.threshold, FLAGS.scaled)
+                [img_file], ori_image_shape, threshold, scaled)
 
         if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
             print('Frame {} has no object, try to modify score threshold.'.
@@ -589,7 +598,7 @@
 
         # reid process
         crops = reid_model.get_crops(pred_xyxys, frame)
-        if FLAGS.run_benchmark:
+        if run_benchmark:
             # warmup
             tracking_outs = reid_model.predict(
                 crops, pred_dets, repeats=10, add_timer=False)
@@ -607,22 +616,34 @@
             online_im = plot_tracking(
                 frame, online_tlwhs, online_ids, online_scores, frame_id=i)
 
-        if FLAGS.save_images:
-            if not os.path.exists(FLAGS.output_dir):
-                os.makedirs(FLAGS.output_dir)
+        if save_images:
+            if not os.path.exists(output_dir):
+                os.makedirs(output_dir)
             img_name = os.path.split(img_file)[-1]
-            out_path = os.path.join(FLAGS.output_dir, img_name)
+            out_path = os.path.join(output_dir, img_name)
             cv2.imwrite(out_path, online_im)
             print("save result to: " + out_path)
 
 
-def predict_video(detector, reid_model, camera_id):
+def predict_video(detector,
+                  reid_model,
+                  video_file,
+                  scaled,
+                  threshold,
+                  output_dir,
+                  save_images=True,
+                  save_mot_txts=True,
+                  draw_center_traj=False,
+                  secs_interval=10,
+                  do_entrance_counting=False,
+                  camera_id=-1):
+    video_name = 'mot_output.mp4'
     if camera_id != -1:
         capture = cv2.VideoCapture(camera_id)
-        video_name = 'mot_output.mp4'
     else:
-        capture = cv2.VideoCapture(FLAGS.video_file)
-        video_name = os.path.split(FLAGS.video_file)[-1]
+        capture = cv2.VideoCapture(video_file)
+        video_name = os.path.split(video_file)[-1]
+
     # Get Video info : resolution, fps, frame count
     width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
     height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
@@ -630,10 +651,10 @@
     frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
     print("fps: %d, frame_count: %d" % (fps, frame_count))
 
-    if not os.path.exists(FLAGS.output_dir):
-        os.makedirs(FLAGS.output_dir)
-    out_path = os.path.join(FLAGS.output_dir, video_name)
-    if not FLAGS.save_images:
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    out_path = os.path.join(output_dir, video_name)
+    if not save_images:
         video_format = 'mp4v'
         fourcc = cv2.VideoWriter_fourcc(*video_format)
         writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
@@ -656,7 +677,7 @@
         timer.tic()
         ori_image_shape = list(frame.shape[:2])
         pred_dets, pred_xyxys = detector.predict([frame], ori_image_shape,
-                                                 FLAGS.threshold, FLAGS.scaled)
+                                                 threshold, scaled)
 
         if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
             print('Frame {} has no object, try to modify score threshold.'.
@@ -677,9 +698,9 @@
             # NOTE: just implement flow statistic for one class
             result = (frame_id + 1, online_tlwhs, online_scores, online_ids)
             statistic = flow_statistic(
-                result, FLAGS.secs_interval, FLAGS.do_entrance_counting,
-                video_fps, entrance, id_set, interval_id_set, in_id_list,
-                out_id_list, prev_center, records)
+                result, secs_interval, do_entrance_counting, video_fps,
+                entrance, id_set, interval_id_set, in_id_list, out_id_list,
+                prev_center, records)
             id_set = statistic['id_set']
             interval_id_set = statistic['interval_id_set']
             in_id_list = statistic['in_id_list']
@@ -697,11 +718,11 @@
             online_scores,
             frame_id=frame_id,
             fps=fps,
-            do_entrance_counting=FLAGS.do_entrance_counting,
+            do_entrance_counting=do_entrance_counting,
             entrance=entrance)
 
-        if FLAGS.save_images:
-            save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
+        if save_images:
+            save_dir = os.path.join(output_dir, video_name.split('.')[-2])
             if not os.path.exists(save_dir):
                 os.makedirs(save_dir)
             cv2.imwrite(
@@ -717,21 +738,21 @@
             if cv2.waitKey(1) & 0xFF == ord('q'):
                 break
 
-    if FLAGS.save_mot_txts:
-        result_filename = os.path.join(FLAGS.output_dir,
+    if save_mot_txts:
+        result_filename = os.path.join(output_dir,
                                        video_name.split('.')[-2] + '.txt')
         write_mot_results(result_filename, results)
 
         result_filename = os.path.join(
-            FLAGS.output_dir, video_name.split('.')[-2] + '_flow_statistic.txt')
+            output_dir, video_name.split('.')[-2] + '_flow_statistic.txt')
         f = open(result_filename, 'w')
         for line in records:
             f.write(line)
         print('Flow statistic save in {}'.format(result_filename))
         f.close()
 
-    if FLAGS.save_images:
-        save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
+    if save_images:
+        save_dir = os.path.join(output_dir, video_name.split('.')[-2])
         cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(save_dir,
                                                               out_path)
         os.system(cmd_str)
@@ -740,8 +761,16 @@
         writer.release()
 
 
-def predict_mtmct_seq(detector, reid_model, seq_name, output_dir):
-    fpath = os.path.join(FLAGS.mtmct_dir, seq_name)
+def predict_mtmct_seq(detector,
+                      reid_model,
+                      mtmct_dir,
+                      seq_name,
+                      scaled,
+                      threshold,
+                      output_dir,
+                      save_images=True,
+                      save_mot_txts=True):
+    fpath = os.path.join(mtmct_dir, seq_name)
     if os.path.exists(os.path.join(fpath, 'img1')):
         fpath = os.path.join(fpath, 'img1')
 
@@ -756,13 +785,13 @@
                                                     len(image_list), seq_name))
 
     for frame_id, img_file in enumerate(image_list):
-        if frame_id % 40 == 0:
+        if frame_id % 10 == 0:
             print('Processing frame {} of seq {}.'.format(frame_id, seq_name))
         frame = cv2.imread(os.path.join(fpath, img_file))
         ori_image_shape = list(frame.shape[:2])
         frame_path = os.path.join(fpath, img_file)
         pred_dets, pred_xyxys = detector.predict([frame_path], ori_image_shape,
-                                                 FLAGS.threshold, FLAGS.scaled)
+                                                 threshold, scaled)
 
         if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
             print('Frame {} has no object, try to modify score threshold.'.
@@ -791,21 +820,29 @@
             results[0].append(
                 (frame_id + 1, online_tlwhs, online_scores, online_ids))
 
-        if FLAGS.save_images:
+        if save_images:
             save_dir = os.path.join(output_dir, seq_name)
             if not os.path.exists(save_dir):
                 os.makedirs(save_dir)
             img_name = os.path.split(img_file)[-1]
             out_path = os.path.join(save_dir, img_name)
             cv2.imwrite(out_path, online_im)
 
-    if FLAGS.save_mot_txts:
+    if save_mot_txts:
         result_filename = os.path.join(output_dir, seq_name + '.txt')
         write_mot_results(result_filename, results)
 
     return mot_features_dict
 
 
-def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
+def predict_mtmct(detector,
+                  reid_model,
+                  mtmct_dir,
+                  mtmct_cfg,
+                  scaled,
+                  threshold,
+                  output_dir,
+                  save_images=True,
+                  save_mot_txts=True):
     MTMCT = mtmct_cfg['MTMCT']
     assert MTMCT == True, 'predict_mtmct should be used for MTMCT.'
 
@@ -832,7 +869,6 @@
     mot_list_breaks = []
     cid_tid_dict = dict()
 
-    output_dir = FLAGS.output_dir
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
     seqs = os.listdir(mtmct_dir)
@@ -852,8 +888,9 @@
             print('{} is not a image folder.'.format(fpath))
             continue
 
-        mot_features_dict = predict_mtmct_seq(detector, reid_model, seq,
-                                              output_dir)
+        mot_features_dict = predict_mtmct_seq(
+            detector, reid_model, mtmct_dir, seq, scaled, threshold, output_dir,
+            save_images, save_mot_txts)
 
         cid = int(re.sub('[a-z,A-Z]', "", seq))
         tid_data, mot_list_break = trajectory_fusion(
@@ -911,6 +948,62 @@
     print_mtmct_result(data_root_gt, pred_mtmct_file)
 
 
+def predict_naive(model_dir,
+                  reid_model_dir,
+                  video_file,
+                  image_dir,
+                  mtmct_dir=None,
+                  mtmct_cfg=None,
+                  scaled=True,
+                  device='gpu',
+                  threshold=0.5,
+                  output_dir='output'):
+    pred_config = PredictConfig(model_dir)
+    detector_func = 'SDE_Detector'
+    if pred_config.arch == 'PicoDet':
+        detector_func = 'SDE_DetectorPicoDet'
+    detector = eval(detector_func)(pred_config, model_dir, device=device.upper())
+
+    pred_config = PredictConfig(reid_model_dir)
+    reid_model = SDE_ReID(pred_config, reid_model_dir, device=device.upper())
+
+    if video_file is not None:
+        predict_video(
+            detector,
+            reid_model,
+            video_file,
+            scaled=scaled,
+            threshold=threshold,
+            output_dir=output_dir,
+            save_images=True,
+            save_mot_txts=True,
+            draw_center_traj=False,
+            secs_interval=10,
+            do_entrance_counting=False)
+    elif mtmct_dir is not None:
+        with open(mtmct_cfg) as f:
+            mtmct_cfg_file = yaml.safe_load(f)
+        predict_mtmct(
+            detector,
+            reid_model,
+            mtmct_dir,
+            mtmct_cfg_file,
+            scaled=scaled,
+            threshold=threshold,
+            output_dir=output_dir,
+            save_images=True,
+            save_mot_txts=True)
+    else:
+        img_list = get_test_images(image_dir, infer_img=None)
+        predict_image(
+            detector,
+            reid_model,
+            img_list,
+            threshold=threshold,
+            output_dir=output_dir,
+            save_images=True)
+
+
 def main():
     pred_config = PredictConfig(FLAGS.model_dir)
     detector_func = 'SDE_Detector'
@@ -945,18 +1038,45 @@
 
     # predict from video file or camera video stream
     if FLAGS.video_file is not None or FLAGS.camera_id != -1:
-        predict_video(detector, reid_model, FLAGS.camera_id)
+        predict_video(
+            detector,
+            reid_model,
+            FLAGS.video_file,
+            scaled=FLAGS.scaled,
+            threshold=FLAGS.threshold,
+            output_dir=FLAGS.output_dir,
+            save_images=FLAGS.save_images,
+            save_mot_txts=FLAGS.save_mot_txts,
+            draw_center_traj=FLAGS.draw_center_traj,
+            secs_interval=FLAGS.secs_interval,
+            do_entrance_counting=FLAGS.do_entrance_counting,
+            camera_id=FLAGS.camera_id)
     elif FLAGS.mtmct_dir is not None:
         mtmct_cfg_file = FLAGS.mtmct_cfg
         with open(mtmct_cfg_file) as f:
             mtmct_cfg = yaml.safe_load(f)
-        predict_mtmct(detector, reid_model, FLAGS.mtmct_dir, mtmct_cfg)
-
+        predict_mtmct(
+            detector,
+            reid_model,
+            FLAGS.mtmct_dir,
+            mtmct_cfg,
+            scaled=FLAGS.scaled,
+            threshold=FLAGS.threshold,
+            output_dir=FLAGS.output_dir,
+            save_images=FLAGS.save_images,
+            save_mot_txts=FLAGS.save_mot_txts)
     else:
         # predict from image
         img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
-        predict_image(detector, reid_model, img_list)
+        predict_image(
+            detector,
+            reid_model,
+            img_list,
+            threshold=FLAGS.threshold,
+            output_dir=FLAGS.output_dir,
+            save_images=FLAGS.save_images,
+            run_benchmark=FLAGS.run_benchmark)
 
     if not FLAGS.run_benchmark:
         detector.det_times.info(average=True)
-- 
GitLab