diff --git a/deploy/pphuman/README.md b/deploy/pphuman/README.md
index 2929eb6f11215d1687de46a1505926951a7387ec..f8949e08d8a02b8e95940375302f62e279e3f8a7 100644
--- a/deploy/pphuman/README.md
+++ b/deploy/pphuman/README.md
@@ -39,7 +39,7 @@ PP-Human provides pretrained models for object detection, attribute recognition, action recognition, and ReID
 | Task | Applicable scenario | Accuracy | Inference speed (FPS) | Deployment model |
 | :---------: |:---------: |:--------------- | :-------: | :------: |
 | Object detection | Image/video input | - | - | [Download](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip) |
-| Attribute recognition | Image/video input, attribute recognition | - | - | [Download](https://bj.bcebos.com/v1/paddledet/models/pipeline/strongbaseline_r50_30e_pa100k.tar) |
+| Attribute recognition | Image/video input, attribute recognition | - | - | [Download](https://bj.bcebos.com/v1/paddledet/models/pipeline/strongbaseline_r50_30e_pa100k.zip) |
 | Keypoint detection | Video input, action recognition | - | - | [Download](https://bj.bcebos.com/v1/paddledet/models/pipeline/dark_hrnet_w32_256x192.zip) |
 | Action recognition | Video input, action recognition | - | - | [Download](https://bj.bcebos.com/v1/paddledet/models/pipeline/STGCN.zip) |
 | ReID | Video input, cross-camera tracking | - | - | [Download]() |
@@ -117,7 +117,8 @@ python deploy/pphuman/pipeline.py --config deploy/pphuman/config/infer_cfg.yml -
 | --enable_mkldnn | Option | Whether to enable MKLDNN acceleration for CPU inference; default is False |
 | --cpu_threads | Option| Number of CPU threads; default is 1 |
 | --trt_calib_mode | Option| Whether TensorRT uses calibration; default is False. Set to True when using TensorRT int8, and to False when using a model quantized with PaddleSlim |
-
+| --do_entrance_counting | Option | Whether to count entrance/exit traffic; default is False |
+| --draw_center_traj | Option | Whether to draw tracking trajectories; default is False |
 
 ## 3. Solution Overview
 
@@ -130,13 +131,13 @@ The overall PP-Human solution is shown in the figure below
 
 ### 1. Object Detection
 - PP-YOLOE-L is used as the object detection model
-- For details, see [PP-YOLOE](../../configs/ppyoloe/)
+- For details, see [PP-YOLOE](../../configs/ppyoloe/) and the [detection and tracking doc](docs/mot.md)
 
 ### 2. Multi-Object Tracking
 - The SDE scheme is used for multi-object tracking
 - The detection model is PP-YOLOE-L
 - The tracking module uses ByteTrack
-- For details, see [ByteTrack](configs/mot/bytetrack)
+- For details, see [ByteTrack](../../configs/mot/bytetrack) and the [detection and tracking doc](docs/mot.md)
 
 ### 3. Cross-Camera Tracking
 - PP-YOLOE + ByteTrack are used to obtain single-camera multi-object tracking trajectories
diff --git a/deploy/pphuman/config/infer_cfg.yml b/deploy/pphuman/config/infer_cfg.yml
index 9e53523aed7cedaa7f208d960bcbf28d09ae1e92..0d4de94c2bfec0b05db1a90691528808d051bc28 100644
--- a/deploy/pphuman/config/infer_cfg.yml
+++ b/deploy/pphuman/config/infer_cfg.yml
@@ -5,7 +5,7 @@ visual: True
 warmup_frame: 50
 
 DET:
-  model_dir: output_inference/mot_ppyolov3/
+  model_dir: output_inference/mot_ppyoloe_l_36e_pipeline/
   batch_size: 1
 
 ATTR:
@@ -13,7 +13,7 @@ ATTR:
   batch_size: 8
 
 MOT:
-  model_dir: output_inference/mot_ppyolov3/
+  model_dir: output_inference/mot_ppyoloe_l_36e_pipeline/
   tracker_config: deploy/pphuman/config/tracker_config.yml
   batch_size: 1
diff --git a/deploy/pphuman/docs/images/mot.gif b/deploy/pphuman/docs/images/mot.gif
new file mode 100644
index 0000000000000000000000000000000000000000..fb1f86af5aa64760d9aeb0d50215cfb3359a98b6
Binary files /dev/null and b/deploy/pphuman/docs/images/mot.gif differ
diff --git a/deploy/pphuman/docs/mot.md b/deploy/pphuman/docs/mot.md
new file mode 100644
index 0000000000000000000000000000000000000000..6dbcd3b6af17ce68b6b3a12dfc39512a001157d9
--- /dev/null
+++ b/deploy/pphuman/docs/mot.md
@@ -0,0 +1,64 @@
+# PP-Human Detection and Tracking Module
+
+Pedestrian detection and tracking is widely used in smart communities, industrial inspection, traffic surveillance, and similar scenarios. PP-Human integrates a detection and tracking module that serves as the basis for keypoint detection, attribute recognition, action recognition, and other tasks. We provide pretrained models that can be downloaded and used directly.
+
+| Task | Algorithm | Accuracy | Inference speed (ms) | Download |
+|:---------------------|:---------:|:------:|:------:| :---------------------------------------------------------------------------------: |
+| Pedestrian detection/tracking | PP-YOLOE | mAP: 56.3 <br> MOTA: 72.0 | Detection: 28ms <br> Tracking: 33.1ms | [Download](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip) |
+
+1. The detection/tracking accuracy is obtained by training and testing on a fusion of MOT17, CrowdHuman, HiEve, and some internal business data.
+2. Inference speed is measured on a T4 GPU with TensorRT FP16.
+
+## Usage
+
+1. Download the model from the link in the table above and extract it to ```./output_inference```
+2. For image input, run:
+```bash
+python deploy/pphuman/pipeline.py --config deploy/pphuman/config/infer_cfg.yml \
+                                  --image_file=test_image.jpg \
+                                  --device=gpu
+```
+3. For video input, run:
+```bash
+python deploy/pphuman/pipeline.py --config deploy/pphuman/config/infer_cfg.yml \
+                                  --video_file=test_video.mp4 \
+                                  --device=gpu
+```
+4. To change the model path, there are two options:
+
+    - Configure model paths in ```./deploy/pphuman/config/infer_cfg.yml```: the detection and tracking models correspond to the `DET` and `MOT` fields respectively; change the path under the corresponding field to the desired one.
+    - Add `--model_dir` on the command line to override the model path:
+```bash
+python deploy/pphuman/pipeline.py --config deploy/pphuman/config/infer_cfg.yml \
+                                  --video_file=test_video.mp4 \
+                                  --device=gpu \
+                                  --model_dir det=ppyoloe/ \
+                                  --do_entrance_counting \
+                                  --draw_center_traj
+```
+**Note:**
+ - `--do_entrance_counting` controls whether entrance/exit traffic is counted; it defaults to False when not set. `--draw_center_traj` controls whether tracking trajectories are drawn; it also defaults to False when not set. Videos used for trajectory drawing should preferably be shot by a stationary camera.
+
+The test result is shown below:
+
+<div width="1000" align="center">
+  <img src="./images/mot.gif"/>
+</div>
+
+Data source and copyright owner: Skyinfor Technology. Thanks for providing and open-sourcing the real-world scenario data, which is used for academic research only.
+
+## Solution
+
+1. Object detection/multi-object tracking obtains the pedestrian bounding boxes from the image/video input. The model is PP-YOLOE; for details see [PP-YOLOE](../../../configs/ppyoloe).
+2. The multi-object tracking solution is based on [ByteTrack](https://arxiv.org/pdf/2110.06864.pdf), with PP-YOLOE replacing the original YOLOX as the detector and BYTETracker as the tracker.
+
+## References
+```
+@article{zhang2021bytetrack,
+  title={ByteTrack: Multi-Object Tracking by Associating Every Detection Box},
+  author={Zhang, Yifu and Sun, Peize and Jiang, Yi and Yu, Dongdong and Yuan, Zehuan and Luo, Ping and Liu, Wenyu and Wang, Xinggang},
+  journal={arXiv preprint arXiv:2110.06864},
+  year={2021}
+}
+```
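The `--do_entrance_counting` flag documented above counts track IDs that cross a virtual entrance line, which this patch places at half the frame height (`entrance = [0, height / 2., width, height / 2.]`). The real bookkeeping lives in `flow_statistic` in `deploy/pptracking/python/mot/utils.py`; the following is only a minimal sketch of the crossing rule under that assumption, and `update_entrance_count` is a hypothetical helper, not part of the patch:

```python
import numpy as np


def update_entrance_count(tlwhs, ids, entrance_y, prev_center, in_id_list,
                          out_id_list):
    """Count track IDs crossing a horizontal entrance line (sketch only).

    tlwhs: (N, 4) array of [x, y, w, h] boxes for the current frame.
    ids:   (N,) array of track IDs.
    prev_center: dict mapping track ID -> last seen (cx, cy).
    """
    for tlwh, track_id in zip(tlwhs, ids):
        x, y, w, h = tlwh
        center = (x + w / 2., y + h / 2.)
        if track_id in prev_center:
            _, prev_y = prev_center[track_id]
            # Crossing top->bottom counts as "in", bottom->top as "out".
            if prev_y <= entrance_y < center[1]:
                in_id_list.append(track_id)
            elif prev_y >= entrance_y > center[1]:
                out_id_list.append(track_id)
        prev_center[track_id] = center


# Example: a line at half the frame height, as in this patch.
height = 1080
prev_center, in_ids, out_ids = {}, [], []
frame1 = (np.array([[100., 400., 50., 100.]]), np.array([7]))
frame2 = (np.array([[100., 560., 50., 100.]]), np.array([7]))
for tlwhs, ids in (frame1, frame2):
    update_entrance_count(tlwhs, ids, height / 2., prev_center, in_ids, out_ids)
print(len(set(in_ids)), len(set(out_ids)))  # 1 0
```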
diff --git a/deploy/pphuman/pipe_utils.py b/deploy/pphuman/pipe_utils.py
index 094cb6a72fe3f04382ec5228760d81ca22e0847f..b55ac9a867cf027662d9b3d84e39f119f28d123a 100644
--- a/deploy/pphuman/pipe_utils.py
+++ b/deploy/pphuman/pipe_utils.py
@@ -108,6 +108,21 @@ def argsparser():
         default=False,
         help="If the model is produced by TRT offline quantitative "
         "calibration, trt_calib_mode need to set True.")
+    parser.add_argument(
+        "--do_entrance_counting",
+        action='store_true',
+        help="Whether to count the number of identities entering "
+        "or leaving the entrance. Note that only single-class "
+        "counting is supported; multi-class counting is coming soon.")
+    parser.add_argument(
+        "--secs_interval",
+        type=int,
+        default=2,
+        help="The interval in seconds for counting during tracking")
+    parser.add_argument(
+        "--draw_center_traj",
+        action='store_true',
+        help="Whether to draw the trajectory of box centers")
     return parser
diff --git a/deploy/pphuman/pipeline.py b/deploy/pphuman/pipeline.py
index 2090b2a0c91e12a0f065bbdb32c74f647af7f4b1..4d6fa014ae783b61c4464b2e292c5d745a5297d1 100644
--- a/deploy/pphuman/pipeline.py
+++ b/deploy/pphuman/pipeline.py
@@ -15,6 +15,7 @@
 import os
 import yaml
 import glob
+from collections import defaultdict
 import cv2
 import numpy as np
@@ -44,7 +45,8 @@
 from python.preprocess import decode_image
 from python.visualize import visualize_box_mask, visualize_attr, visualize_pose, visualize_action
 
 from pptracking.python.mot_sde_infer import SDE_Detector
-from pptracking.python.mot.visualize import plot_tracking
+from pptracking.python.mot.visualize import plot_tracking_dict
+from pptracking.python.mot.utils import flow_statistic
 
 
 class Pipeline(object):
@@ -72,6 +74,11 @@ class Pipeline(object):
         cpu_threads (int): cpu threads, default as 1
         enable_mkldnn (bool): whether to open MKLDNN, default as False
         output_dir (string): The path of output, default as 'output'
+        draw_center_traj (bool): Whether to draw the trajectory of box centers, default as False
+        secs_interval (int): The interval in seconds for counting during tracking, default as 10
+        do_entrance_counting (bool): Whether to count the number of identities entering
+            or leaving the entrance, default as False. Only single-class counting
+            is supported in MOT.
""" def __init__(self, @@ -91,7 +98,10 @@ class Pipeline(object): trt_calib_mode=False, cpu_threads=1, enable_mkldnn=False, - output_dir='output'): + output_dir='output', + draw_center_traj=False, + secs_interval=10, + do_entrance_counting=False): self.multi_camera = False self.is_video = False self.output_dir = output_dir @@ -129,10 +139,18 @@ class Pipeline(object): trt_calib_mode=trt_calib_mode, cpu_threads=cpu_threads, enable_mkldnn=enable_mkldnn, - output_dir=output_dir) + output_dir=output_dir, + draw_center_traj=draw_center_traj, + secs_interval=secs_interval, + do_entrance_counting=do_entrance_counting) if self.is_video: self.predictor.set_file_name(video_file) + self.output_dir = output_dir + self.draw_center_traj = draw_center_traj + self.secs_interval = secs_interval + self.do_entrance_counting = do_entrance_counting + def _parse_input(self, image_file, image_dir, video_file, video_dir, camera_id): @@ -144,6 +162,7 @@ class Pipeline(object): self.multi_camera = False elif video_file is not None: + assert os.path.exists(video_file), "video_file not exists." self.multi_camera = False input = video_file self.is_video = True @@ -222,6 +241,11 @@ class PipePredictor(object): cpu_threads (int): cpu threads, default as 1 enable_mkldnn (bool): whether to open MKLDNN, default as False output_dir (string): The path of output, default as 'output' + draw_center_traj (bool): Whether drawing the trajectory of center, default as False + secs_interval (int): The seconds interval to count after tracking, default as 10 + do_entrance_counting(bool): Whether counting the numbers of identifiers entering + or getting out from the entrance, default as False,only support single class + counting in MOT. """ def __init__(self, @@ -238,7 +262,10 @@ class PipePredictor(object): trt_calib_mode=False, cpu_threads=1, enable_mkldnn=False, - output_dir='output'): + output_dir='output', + draw_center_traj=False, + secs_interval=10, + do_entrance_counting=False): if enable_attr and not cfg.get('ATTR', False): ValueError( @@ -268,6 +295,9 @@ class PipePredictor(object): self.multi_camera = multi_camera self.cfg = cfg self.output_dir = output_dir + self.draw_center_traj = draw_center_traj + self.secs_interval = secs_interval + self.do_entrance_counting = do_entrance_counting self.warmup_frame = self.cfg['warmup_frame'] self.pipeline_res = Result() @@ -298,9 +328,20 @@ class PipePredictor(object): tracker_config = mot_cfg['tracker_config'] batch_size = mot_cfg['batch_size'] self.mot_predictor = SDE_Detector( - model_dir, tracker_config, device, run_mode, batch_size, - trt_min_shape, trt_max_shape, trt_opt_shape, trt_calib_mode, - cpu_threads, enable_mkldnn) + model_dir, + tracker_config, + device, + run_mode, + batch_size, + trt_min_shape, + trt_max_shape, + trt_opt_shape, + trt_calib_mode, + cpu_threads, + enable_mkldnn, + draw_center_traj=draw_center_traj, + secs_interval=secs_interval, + do_entrance_counting=do_entrance_counting) if self.with_attr: attr_cfg = self.cfg['ATTR'] model_dir = attr_cfg['model_dir'] @@ -431,6 +472,7 @@ class PipePredictor(object): height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = int(capture.get(cv2.CAP_PROP_FPS)) frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) + print("video fps: %d, frame_count: %d" % (fps, frame_count)) if not os.path.exists(self.output_dir): os.makedirs(self.output_dir) @@ -438,6 +480,19 @@ class PipePredictor(object): fourcc = cv2.VideoWriter_fourcc(* 'mp4v') writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) frame_id = 0 + + 
entrance, records, center_traj = None, None, None + if self.draw_center_traj: + center_traj = [{}] + id_set = set() + interval_id_set = set() + in_id_list = list() + out_id_list = list() + prev_center = dict() + records = list() + entrance = [0, height / 2., width, height / 2.] + video_fps = fps + while (1): if frame_id % 10 == 0: print('frame id: ', frame_id) @@ -457,6 +512,16 @@ class PipePredictor(object): # mot output format: id, class, score, xmin, ymin, xmax, ymax mot_res = parse_mot_res(res) + # flow_statistic only support single class MOT + boxes, scores, ids = res[0] # batch size = 1 in MOT + mot_result = (frame_id + 1, boxes[0], scores[0], + ids[0]) # single class + statistic = flow_statistic( + mot_result, self.secs_interval, self.do_entrance_counting, + video_fps, entrance, id_set, interval_id_set, in_id_list, + out_id_list, prev_center, records) + records = statistic['records'] + # nothing detected if len(mot_res['boxes']) == 0: frame_id += 1 @@ -549,13 +614,21 @@ class PipePredictor(object): if self.cfg['visual']: _, _, fps = self.pipe_timer.get_total_time() im = self.visualize_video(frame, self.pipeline_res, frame_id, - fps) # visualize + fps, entrance, records, + center_traj) # visualize writer.write(im) writer.release() print('save result to {}'.format(out_path)) - def visualize_video(self, image, result, frame_id, fps): + def visualize_video(self, + image, + result, + frame_id, + fps, + entrance=None, + records=None, + center_traj=None): mot_res = copy.deepcopy(result.get('mot')) if mot_res is not None: ids = mot_res['boxes'][:, 0] @@ -567,8 +640,28 @@ class PipePredictor(object): boxes = np.zeros([0, 4]) ids = np.zeros([0]) scores = np.zeros([0]) - image = plot_tracking( - image, boxes, ids, scores, frame_id=frame_id, fps=fps) + + # single class, still need to be defaultdict type for ploting + num_classes = 1 + online_tlwhs = defaultdict(list) + online_scores = defaultdict(list) + online_ids = defaultdict(list) + online_tlwhs[0] = boxes + online_scores[0] = scores + online_ids[0] = ids + + image = plot_tracking_dict( + image, + num_classes, + online_tlwhs, + online_ids, + online_scores, + frame_id=frame_id, + fps=fps, + do_entrance_counting=self.do_entrance_counting, + entrance=entrance, + records=records, + center_traj=center_traj) attr_res = result.get('attr') if attr_res is not None: @@ -630,7 +723,8 @@ def main(): FLAGS.video_dir, FLAGS.camera_id, FLAGS.enable_attr, FLAGS.enable_action, FLAGS.device, FLAGS.run_mode, FLAGS.trt_min_shape, FLAGS.trt_max_shape, FLAGS.trt_opt_shape, FLAGS.trt_calib_mode, - FLAGS.cpu_threads, FLAGS.enable_mkldnn, FLAGS.output_dir) + FLAGS.cpu_threads, FLAGS.enable_mkldnn, FLAGS.output_dir, + FLAGS.draw_center_traj, FLAGS.secs_interval, FLAGS.do_entrance_counting) pipeline.run() diff --git a/deploy/pptracking/python/README.md b/deploy/pptracking/python/README.md index 6b75568983ce35ffa75919737703188280fba013..d5c34cdf56efec0f0dd7686d2127c33e584eaf37 100644 --- a/deploy/pptracking/python/README.md +++ b/deploy/pptracking/python/README.md @@ -35,10 +35,21 @@ wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/mot17_demo.mp4 # Python预测视频 python deploy/pptracking/python/mot_jde_infer.py --model_dir=output_inference/fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video_file=mot17_demo.mp4 --device=GPU --threshold=0.5 --save_mot_txts --save_images ``` + +### 1.3 用导出的模型基于Python去预测,以及进行流量计数、出入口统计和绘制跟踪轨迹等 +```bash +# 下载出入口统计demo视频: +wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/entrance_count_demo.mp4 + +# Python预测视频 +python 
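For completeness, here is a minimal programmatic sketch of driving the PP-Human pipeline with the new counters, mirroring `main()` above. It assumes `cfg` is the first positional argument of `Pipeline` and that `video_file`/`device` keywords exist, which this hunk only partially shows; the config and video paths are placeholders, and the repository root is assumed to be on `PYTHONPATH`.

```python
import yaml

# Assumes the repository root is on PYTHONPATH.
from deploy.pphuman.pipeline import Pipeline

with open('deploy/pphuman/config/infer_cfg.yml') as f:
    cfg = yaml.safe_load(f)

pipeline = Pipeline(
    cfg,
    video_file='test_video.mp4',
    device='gpu',
    draw_center_traj=True,       # overlay box-center trajectories
    secs_interval=2,             # report counts every 2 seconds
    do_entrance_counting=True)   # count crossings of the entrance line
pipeline.run()
```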
diff --git a/deploy/pptracking/python/README.md b/deploy/pptracking/python/README.md
index 6b75568983ce35ffa75919737703188280fba013..d5c34cdf56efec0f0dd7686d2127c33e584eaf37 100644
--- a/deploy/pptracking/python/README.md
+++ b/deploy/pptracking/python/README.md
@@ -35,10 +35,21 @@
 wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/mot17_demo.mp4
 
 # Predict a video with Python
 python deploy/pptracking/python/mot_jde_infer.py --model_dir=output_inference/fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video_file=mot17_demo.mp4 --device=GPU --threshold=0.5 --save_mot_txts --save_images
 ```
+
+### 1.3 Predict with the exported model in Python, with flow counting, entrance/exit statistics, and trajectory drawing
+```bash
+# Download the entrance-counting demo video:
+wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/entrance_count_demo.mp4
+
+# Predict a video with Python
+python deploy/pptracking/python/mot_jde_infer.py --model_dir=output_inference/fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video_file=entrance_count_demo.mp4 --device=GPU --do_entrance_counting --draw_center_traj
+```
+
 **Note:**
  - The tracking model runs on videos; prediction on a single image is not supported. The visualized tracking result is saved as a video by default. Add `--save_mot_txts` to save the tracking results as txt files, or `--save_images` to save the visualized frames.
  - Each line of the tracking result txt file is `frame,id,x1,y1,w,h,score,-1,-1,-1`.
  - `--threshold` is the confidence threshold for result visualization, 0.5 by default; results below it are filtered out. It can be adjusted as needed for a better visual effect.
+ - `--do_entrance_counting` controls whether entrance/exit traffic is counted and defaults to False; `--draw_center_traj` controls whether tracking trajectories are drawn and defaults to False. Videos used for trajectory drawing should preferably be shot by a stationary camera.
  - For the export and Python prediction of multi-class or vehicle FairMOT models, just change the corresponding config and model weights. For example:
 ```bash
 job_name=mcfairmot_hrnetv2_w18_dlafpn_30e_576x320_visdrone
diff --git a/deploy/pptracking/python/mot/mtmct/postprocess.py b/deploy/pptracking/python/mot/mtmct/postprocess.py
index 4bb7ef949bdeaebd153d02e0296df47bbcf909fd..7e338b901fa75716dacc2dc0560bbbeafe53573a 100644
--- a/deploy/pptracking/python/mot/mtmct/postprocess.py
+++ b/deploy/pptracking/python/mot/mtmct/postprocess.py
@@ -27,7 +27,7 @@
 from .utils import parse_pt_gt, parse_pt, compare_dataframes_mtmc
 from .utils import get_labels, getData, gen_new_mot
 from .camera_utils import get_labels_with_camera
 from .zone import Zone
-from ..utils import plot_tracking
+from ..visualize import plot_tracking
 
 __all__ = [
     'trajectory_fusion',
@@ -68,8 +68,8 @@ def trajectory_fusion(mot_feature, cid, cid_bias, use_zone=False, zone_path=''):
         zone_list = [tracklet[f]['zone'] for f in frame_list]
         feature_list = [
             tracklet[f]['feat'] for f in frame_list
-            if (tracklet[f]['bbox'][3] - tracklet[f]['bbox'][1]
-                ) * (tracklet[f]['bbox'][2] - tracklet[f]['bbox'][0]) > 2000
+            if (tracklet[f]['bbox'][3] - tracklet[f]['bbox'][1]) *
+            (tracklet[f]['bbox'][2] - tracklet[f]['bbox'][0]) > 2000
         ]
         if len(feature_list) < 2:
             feature_list = [tracklet[f]['feat'] for f in frame_list]
@@ -293,9 +293,9 @@ def save_mtmct_crops(cid_tid_fid_res,
         for f_id in cid_tid_fid_res[c_id][t_id].keys():
             frame_idx = f_id - 1 if f_id > 0 else 0
             im_path = os.path.join(infer_dir, all_images[frame_idx])
-            
+
             im = cv2.imread(im_path)  # (H, W, 3)
-            
+
             # only select one track
             track = cid_tid_fid_res[c_id][t_id][f_id][0]
diff --git a/deploy/pptracking/python/mot/utils.py b/deploy/pptracking/python/mot/utils.py
index 37d39b066671e20c4030eb06e7e5698ecfb4cf68..8bb380af0874e9ee795f7616cc14c0abf55eb320 100644
--- a/deploy/pptracking/python/mot/utils.py
+++ b/deploy/pptracking/python/mot/utils.py
@@ -20,8 +20,7 @@
 
 __all__ = [
     'MOTTimer', 'Detection', 'write_mot_results', 'load_det_results',
-    'preprocess_reid', 'get_crops', 'clip_box', 'scale_coords', 'flow_statistic',
-    'plot_tracking'
+    'preprocess_reid', 'get_crops', 'clip_box', 'scale_coords', 'flow_statistic'
 ]
 
@@ -182,7 +181,7 @@ def clip_box(xyxy, ori_image_shape):
 def get_crops(xyxy, ori_img, w, h):
     crops = []
     xyxy = xyxy.astype(np.int64)
-    ori_img = ori_img.transpose(1, 0, 2) # [h,w,3]->[w,h,3]
+    ori_img = ori_img.transpose(1, 0, 2)  # [h,w,3]->[w,h,3]
     for i, bbox in enumerate(xyxy):
         crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
         crops.append(crop)
@@ -197,10 +196,7 @@ def preprocess_reid(imgs,
                     std=[0.229, 0.224, 0.225]):
     im_batch = []
     for img in imgs:
-        try:
-            img = cv2.resize(img, (w, h))
-        except:
-            embed()
+        img = cv2.resize(img, (w, h))
         img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
         img_mean = np.array(mean).reshape((3, 1, 1))
         img_std = np.array(std).reshape((3, 1, 1))
@@ -288,77 +284,3 @@ def flow_statistic(result,
         "prev_center": prev_center,
         "records": records
     }
-
-
-def get_color(idx):
-    idx = idx * 3
-    color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
-    return color
-
-
-def plot_tracking(image,
-                  tlwhs,
-                  obj_ids,
-                  scores=None,
-                  frame_id=0,
-                  fps=0.,
-                  ids2names=[],
-                  do_entrance_counting=False,
-                  entrance=None):
-    im = np.ascontiguousarray(np.copy(image))
-    im_h, im_w = im.shape[:2]
-
-    text_scale = max(1, image.shape[1] / 1600.)
-    text_thickness = 2
-    line_thickness = max(1, int(image.shape[1] / 500.))
-
-    if fps > 0:
-        _line = 'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs))
-    else:
-        _line = 'frame: %d num: %d' % (frame_id, len(tlwhs))
-    cv2.putText(
-        im,
-        _line, (0, int(15 * text_scale)),
-        cv2.FONT_HERSHEY_PLAIN,
-        text_scale, (0, 0, 255),
-        thickness=2)
-
-    for i, tlwh in enumerate(tlwhs):
-        x1, y1, w, h = tlwh
-        intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
-        obj_id = int(obj_ids[i])
-        id_text = '{}'.format(int(obj_id))
-        if ids2names != []:
-            assert len(
-                ids2names) == 1, "plot_tracking only supports single classes."
-            id_text = '{}_'.format(ids2names[0]) + id_text
-        _line_thickness = 1 if obj_id <= 0 else line_thickness
-        color = get_color(abs(obj_id))
-        cv2.rectangle(
-            im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
-        cv2.putText(
-            im,
-            id_text, (intbox[0], intbox[1] - 10),
-            cv2.FONT_HERSHEY_PLAIN,
-            text_scale, (0, 0, 255),
-            thickness=text_thickness)
-
-        if scores is not None:
-            text = '{:.2f}'.format(float(scores[i]))
-            cv2.putText(
-                im,
-                text, (intbox[0], intbox[1] + 10),
-                cv2.FONT_HERSHEY_PLAIN,
-                text_scale, (0, 255, 255),
-                thickness=text_thickness)
-
-    if do_entrance_counting:
-        entrance_line = tuple(map(int, entrance))
-        cv2.rectangle(
-            im,
-            entrance_line[0:2],
-            entrance_line[2:4],
-            color=(0, 255, 255),
-            thickness=line_thickness)
-    return im
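The `flow_statistic` call sites added in this patch all share the same contract: a set of mutable containers is created once before the video loop, updated in place on every frame, and the returned dict carries the accumulated `records`. A minimal per-frame driver under that contract follows; `fake_frames` is a hypothetical generator standing in for the real video loop (single class only, boxes in tlwh format).

```python
import numpy as np

# Import path as used in mot_jde_infer.py / mot_sde_infer.py in this patch.
from mot.utils import flow_statistic


def fake_frames():
    # Two frames of one track (tlwh boxes) moving downward across the line.
    yield np.array([[100., 400., 50., 100.]]), np.array([0.9]), np.array([1])
    yield np.array([[100., 600., 50., 100.]]), np.array([0.9]), np.array([1])


video_fps, width, height = 30, 1920, 1080
secs_interval, do_entrance_counting = 2, True
entrance = [0, height / 2., width, height / 2.]  # line at half frame height

# Mutable state created once before the loop, as in predict_video.
id_set, interval_id_set = set(), set()
in_id_list, out_id_list = [], []
prev_center, records = {}, []

for frame_id, (tlwhs, scores, ids) in enumerate(fake_frames()):
    result = (frame_id + 1, tlwhs, scores, ids)  # single class
    statistic = flow_statistic(
        result, secs_interval, do_entrance_counting, video_fps, entrance,
        id_set, interval_id_set, in_id_list, out_id_list, prev_center, records)
    records = statistic['records']  # accumulated human-readable count lines
```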
""" - def __init__(self, - model_dir, - tracker_config=None, - device='CPU', - run_mode='paddle', - batch_size=1, - trt_min_shape=1, - trt_max_shape=1088, - trt_opt_shape=608, - trt_calib_mode=False, - cpu_threads=1, - enable_mkldnn=False, - output_dir='output', - threshold=0.5): + def __init__( + self, + model_dir, + tracker_config=None, + device='CPU', + run_mode='paddle', + batch_size=1, + trt_min_shape=1, + trt_max_shape=1088, + trt_opt_shape=608, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False, + output_dir='output', + threshold=0.5, + save_images=False, + save_mot_txts=False, + draw_center_traj=False, + secs_interval=10, + do_entrance_counting=False, ): super(JDE_Detector, self).__init__( model_dir=model_dir, device=device, @@ -84,6 +99,12 @@ class JDE_Detector(Detector): enable_mkldnn=enable_mkldnn, output_dir=output_dir, threshold=threshold, ) + self.save_images = save_images + self.save_mot_txts = save_mot_txts + self.draw_center_traj = draw_center_traj + self.secs_interval = secs_interval + self.do_entrance_counting = do_entrance_counting + assert batch_size == 1, "MOT model only supports batch_size=1." self.det_times = Timer(with_tracker=True) self.num_classes = len(self.pred_config.labels) @@ -115,7 +136,7 @@ class JDE_Detector(Detector): return result def tracking(self, det_results): - pred_dets = det_results['pred_dets'] + pred_dets = det_results['pred_dets'] # cls_id, score, x0, y0, x1, y1 pred_embs = det_results['pred_embs'] online_targets_dict = self.tracker.update(pred_dets, pred_embs) @@ -164,7 +185,8 @@ class JDE_Detector(Detector): image_list, run_benchmark=False, repeats=1, - visual=True): + visual=True, + seq_name=None): mot_results = [] num_classes = self.num_classes image_list.sort() @@ -225,7 +247,7 @@ class JDE_Detector(Detector): self.det_times.img_num += 1 if visual: - if frame_id % 10 == 0: + if len(image_list) > 1 and frame_id % 10 == 0: print('Tracking frame {}'.format(frame_id)) frame, _ = decode_image(img_file, {}) @@ -237,7 +259,8 @@ class JDE_Detector(Detector): online_scores, frame_id=frame_id, ids2names=ids2names) - seq_name = image_list[0].split('/')[-2] + if seq_name is None: + seq_name = image_list[0].split('/')[-2] save_dir = os.path.join(self.output_dir, seq_name) if not os.path.exists(save_dir): os.makedirs(save_dir) @@ -264,7 +287,8 @@ class JDE_Detector(Detector): if not os.path.exists(self.output_dir): os.makedirs(self.output_dir) out_path = os.path.join(self.output_dir, video_out_name) - fourcc = cv2.VideoWriter_fourcc(* 'mp4v') + video_format = 'mp4v' + fourcc = cv2.VideoWriter_fourcc(*video_format) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) frame_id = 1 @@ -273,6 +297,23 @@ class JDE_Detector(Detector): num_classes = self.num_classes data_type = 'mcmot' if num_classes > 1 else 'mot' ids2names = self.pred_config.labels + + center_traj = None + entrance = None + records = None + if self.draw_center_traj: + center_traj = [{} for i in range(num_classes)] + if num_classes == 1: + id_set = set() + interval_id_set = set() + in_id_list = list() + out_id_list = list() + prev_center = dict() + records = list() + entrance = [0, height / 2., width, height / 2.] 
+ + video_fps = fps + while (1): ret, frame = capture.read() if not ret: @@ -282,7 +323,9 @@ class JDE_Detector(Detector): frame_id += 1 timer.tic() - mot_results = self.predict_image([frame], visual=False) + seq_name = video_out_name.split('.')[0] + mot_results = self.predict_image( + [frame], visual=False, seq_name=seq_name) timer.toc() online_tlwhs, online_scores, online_ids = mot_results[0] @@ -291,6 +334,16 @@ class JDE_Detector(Detector): (frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id], online_ids[cls_id])) + # NOTE: just implement flow statistic for single class + if num_classes == 1: + result = (frame_id + 1, online_tlwhs[0], online_scores[0], + online_ids[0]) + statistic = flow_statistic( + result, self.secs_interval, self.do_entrance_counting, + video_fps, entrance, id_set, interval_id_set, in_id_list, + out_id_list, prev_center, records, data_type, num_classes) + records = statistic['records'] + fps = 1. / timer.duration im = plot_tracking_dict( frame, @@ -300,27 +353,57 @@ class JDE_Detector(Detector): online_scores, frame_id=frame_id, fps=fps, - ids2names=ids2names) + ids2names=ids2names, + do_entrance_counting=self.do_entrance_counting, + entrance=entrance, + records=records, + center_traj=center_traj) writer.write(im) if camera_id != -1: cv2.imshow('Mask Detection', im) if cv2.waitKey(1) & 0xFF == ord('q'): break + + if self.save_mot_txts: + result_filename = os.path.join( + self.output_dir, video_out_name.split('.')[-2] + '.txt') + + write_mot_results(result_filename, results, data_type, num_classes) + + if num_classes == 1: + result_filename = os.path.join( + self.output_dir, + video_out_name.split('.')[-2] + '_flow_statistic.txt') + f = open(result_filename, 'w') + for line in records: + f.write(line) + print('Flow statistic save in {}'.format(result_filename)) + f.close() + writer.release() def main(): detector = JDE_Detector( FLAGS.model_dir, + tracker_config=None, device=FLAGS.device, run_mode=FLAGS.run_mode, + batch_size=1, trt_min_shape=FLAGS.trt_min_shape, trt_max_shape=FLAGS.trt_max_shape, trt_opt_shape=FLAGS.trt_opt_shape, trt_calib_mode=FLAGS.trt_calib_mode, cpu_threads=FLAGS.cpu_threads, - enable_mkldnn=FLAGS.enable_mkldnn) + enable_mkldnn=FLAGS.enable_mkldnn, + output_dir=FLAGS.output_dir, + threshold=FLAGS.threshold, + save_images=FLAGS.save_images, + save_mot_txts=FLAGS.save_mot_txts, + draw_center_traj=FLAGS.draw_center_traj, + secs_interval=FLAGS.secs_interval, + do_entrance_counting=FLAGS.do_entrance_counting, ) # predict from video file or camera video stream if FLAGS.video_file is not None or FLAGS.camera_id != -1: diff --git a/deploy/pptracking/python/mot_sde_infer.py b/deploy/pptracking/python/mot_sde_infer.py index 9eac91278bd966487c6e13434fd888cce32dbbe8..62907ba240a34facc1264ebd3b1092c66dcdef99 100644 --- a/deploy/pptracking/python/mot_sde_infer.py +++ b/deploy/pptracking/python/mot_sde_infer.py @@ -33,7 +33,7 @@ sys.path.insert(0, parent_path) from det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor from mot_utils import argsparser, Timer, get_current_memory_mb, video2frames, _is_valid_video from mot.tracker import JDETracker, DeepSORTTracker -from mot.utils import MOTTimer, write_mot_results, flow_statistic, get_crops, clip_box +from mot.utils import MOTTimer, write_mot_results, get_crops, clip_box, flow_statistic from mot.visualize import plot_tracking, plot_tracking_dict from mot.mtmct.utils import parse_bias @@ -56,6 +56,15 @@ class SDE_Detector(Detector): calibration, trt_calib_mode 
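`plot_tracking_dict` expects per-class mappings (class id to arrays) even in the single-class case, as `visualize_video` in `pipeline.py` above shows. The following is a small self-contained sketch of that adaptation; `to_per_class` is a hypothetical helper, not part of the patch.

```python
from collections import defaultdict

import numpy as np


def to_per_class(boxes, ids, scores):
    """Wrap flat single-class arrays in the defaultdicts plot_tracking_dict expects."""
    online_tlwhs = defaultdict(list)
    online_ids = defaultdict(list)
    online_scores = defaultdict(list)
    online_tlwhs[0] = np.asarray(boxes)   # tlwh boxes of class 0
    online_ids[0] = np.asarray(ids)
    online_scores[0] = np.asarray(scores)
    return online_tlwhs, online_ids, online_scores


tlwhs, ids, scores = to_per_class([[10., 20., 30., 60.]], [1], [0.9])
# image = plot_tracking_dict(image, 1, tlwhs, ids, scores, frame_id=0, fps=25.)
```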
diff --git a/deploy/pptracking/python/mot_sde_infer.py b/deploy/pptracking/python/mot_sde_infer.py
index 9eac91278bd966487c6e13434fd888cce32dbbe8..62907ba240a34facc1264ebd3b1092c66dcdef99 100644
--- a/deploy/pptracking/python/mot_sde_infer.py
+++ b/deploy/pptracking/python/mot_sde_infer.py
@@ -33,7 +33,7 @@
 sys.path.insert(0, parent_path)
 
 from det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
 from mot_utils import argsparser, Timer, get_current_memory_mb, video2frames, _is_valid_video
 from mot.tracker import JDETracker, DeepSORTTracker
-from mot.utils import MOTTimer, write_mot_results, flow_statistic, get_crops, clip_box
+from mot.utils import MOTTimer, write_mot_results, get_crops, clip_box, flow_statistic
 from mot.visualize import plot_tracking, plot_tracking_dict
 from mot.mtmct.utils import parse_bias
@@ -56,6 +56,15 @@ class SDE_Detector(Detector):
             calibration, trt_calib_mode need to set True
         cpu_threads (int): cpu threads
         enable_mkldnn (bool): whether to open MKLDNN
+        output_dir (string): The path of output, default as 'output'
+        threshold (float): Score threshold of the detected bbox, default as 0.5
+        save_images (bool): Whether to save visualization image results, default as False
+        save_mot_txts (bool): Whether to save tracking results (txt), default as False
+        draw_center_traj (bool): Whether to draw the trajectory of box centers, default as False
+        secs_interval (int): The interval in seconds for counting during tracking, default as 10
+        do_entrance_counting (bool): Whether to count the number of identities entering
+            or leaving the entrance, default as False. Only single-class counting
+            is supported in MOT.
         reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT
         mtmct_dir (str): MTMCT dir, default None, set for doing MTMCT
     """
@@ -74,6 +83,11 @@
                  enable_mkldnn=False,
                  output_dir='output',
                  threshold=0.5,
+                 save_images=False,
+                 save_mot_txts=False,
+                 draw_center_traj=False,
+                 secs_interval=10,
+                 do_entrance_counting=False,
                  reid_model_dir=None,
                  mtmct_dir=None):
         super(SDE_Detector, self).__init__(
@@ -89,6 +103,12 @@
             enable_mkldnn=enable_mkldnn,
             output_dir=output_dir,
             threshold=threshold, )
+        self.save_images = save_images
+        self.save_mot_txts = save_mot_txts
+        self.draw_center_traj = draw_center_traj
+        self.secs_interval = secs_interval
+        self.do_entrance_counting = do_entrance_counting
+
         assert batch_size == 1, "MOT model only supports batch_size=1."
         self.det_times = Timer(with_tracker=True)
         self.num_classes = len(self.pred_config.labels)
@@ -309,6 +329,7 @@
                 feat_data['feat'] = _feat
                 tracking_outs['feat_data'].update({_imgname: feat_data})
             return tracking_outs
+
         else:
             tracking_outs = {
                 'online_tlwhs': online_tlwhs,
@@ -409,7 +430,7 @@
             mot_results.append([online_tlwhs, online_scores, online_ids])
 
             if visual:
-                if frame_id % 10 == 0:
+                if len(image_list) > 1 and frame_id % 10 == 0:
                     print('Tracking frame {}'.format(frame_id))
                 frame, _ = decode_image(img_file, {})
                 if isinstance(online_tlwhs, defaultdict):
@@ -456,13 +477,32 @@
         if not os.path.exists(self.output_dir):
             os.makedirs(self.output_dir)
         out_path = os.path.join(self.output_dir, video_out_name)
-        fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
+        video_format = 'mp4v'
+        fourcc = cv2.VideoWriter_fourcc(*video_format)
         writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
         frame_id = 1
 
         timer = MOTTimer()
-        results = defaultdict(list)  # support single class and multi classes
+        results = defaultdict(list)
         num_classes = self.num_classes
+        data_type = 'mcmot' if num_classes > 1 else 'mot'
+        ids2names = self.pred_config.labels
+
+        center_traj = None
+        entrance = None
+        records = None
+        if self.draw_center_traj:
+            center_traj = [{} for i in range(num_classes)]
+        if num_classes == 1:
+            id_set = set()
+            interval_id_set = set()
+            in_id_list = list()
+            out_id_list = list()
+            prev_center = dict()
+            records = list()
+            entrance = [0, height / 2., width, height / 2.]
+
+        video_fps = fps
+
         while (1):
             ret, frame = capture.read()
             if not ret:
@@ -477,10 +517,21 @@
                 [frame], visual=False, seq_name=seq_name)
             timer.toc()
 
-            online_tlwhs, online_scores, online_ids = mot_results[
-                0]  # bs=1 in MOT model
+            # bs=1 in MOT model
+            online_tlwhs, online_scores, online_ids = mot_results[0]
+
+            # NOTE: flow statistics are only implemented for single-class MOT
+            if num_classes == 1:
+                result = (frame_id + 1, online_tlwhs[0], online_scores[0],
+                          online_ids[0])
+                statistic = flow_statistic(
+                    result, self.secs_interval, self.do_entrance_counting,
+                    video_fps, entrance, id_set, interval_id_set, in_id_list,
+                    out_id_list, prev_center, records, data_type, num_classes)
+                records = statistic['records']
+
             fps = 1. / timer.duration
-            if num_classes == 1 and self.use_reid:
+            if self.use_deepsort_tracker:
                 # use DeepSORTTracker, only support singe class
                 results[0].append(
                     (frame_id + 1, online_tlwhs, online_scores, online_ids))
@@ -490,7 +541,9 @@
                     online_ids,
                     online_scores,
                     frame_id=frame_id,
-                    fps=fps)
+                    fps=fps,
+                    do_entrance_counting=self.do_entrance_counting,
+                    entrance=entrance)
             else:
                 # use ByteTracker, support multiple class
                 for cls_id in range(num_classes):
@@ -505,13 +558,32 @@
                     online_scores,
                     frame_id=frame_id,
                     fps=fps,
-                    ids2names=[])
+                    ids2names=ids2names,
+                    do_entrance_counting=self.do_entrance_counting,
+                    entrance=entrance,
+                    records=records,
+                    center_traj=center_traj)
 
             writer.write(im)
             if camera_id != -1:
                 cv2.imshow('Mask Detection', im)
                 if cv2.waitKey(1) & 0xFF == ord('q'):
                     break
+
+        if self.save_mot_txts:
+            result_filename = os.path.join(
+                self.output_dir, video_out_name.split('.')[-2] + '.txt')
+            write_mot_results(result_filename, results)
+
+            result_filename = os.path.join(
+                self.output_dir,
+                video_out_name.split('.')[-2] + '_flow_statistic.txt')
+            f = open(result_filename, 'w')
+            for line in records:
+                f.write(line)
+            print('Flow statistic saved to {}'.format(result_filename))
+            f.close()
+
         writer.release()
 
     def predict_mtmct(self, mtmct_dir, mtmct_cfg):
@@ -623,18 +695,23 @@ def main():
     arch = yml_conf['arch']
     detector = SDE_Detector(
         FLAGS.model_dir,
-        FLAGS.tracker_config,
+        tracker_config=FLAGS.tracker_config,
         device=FLAGS.device,
         run_mode=FLAGS.run_mode,
-        batch_size=FLAGS.batch_size,
+        batch_size=1,
         trt_min_shape=FLAGS.trt_min_shape,
         trt_max_shape=FLAGS.trt_max_shape,
         trt_opt_shape=FLAGS.trt_opt_shape,
         trt_calib_mode=FLAGS.trt_calib_mode,
         cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
-        threshold=FLAGS.threshold,
         output_dir=FLAGS.output_dir,
+        threshold=FLAGS.threshold,
+        save_images=FLAGS.save_images,
+        save_mot_txts=FLAGS.save_mot_txts,
+        draw_center_traj=FLAGS.draw_center_traj,
+        secs_interval=FLAGS.secs_interval,
+        do_entrance_counting=FLAGS.do_entrance_counting,
         reid_model_dir=FLAGS.reid_model_dir,
         mtmct_dir=FLAGS.mtmct_dir, )
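Both `predict_video` implementations initialize `center_traj = [{} for i in range(num_classes)]` and hand it to `plot_tracking_dict`, which fills and draws it. The sketch below assumes each per-class dict maps a track id to its recent box centers; the exact layout inside `plot_tracking_dict` may differ, so treat this only as an illustration of the data shape, and `record_centers` as a hypothetical helper.

```python
import numpy as np

num_classes = 1
center_traj = [{} for _ in range(num_classes)]  # one dict per class


def record_centers(center_traj, cls_id, tlwhs, ids, max_len=30):
    # Append each track's current box center, keeping a bounded history.
    for tlwh, tid in zip(tlwhs, ids):
        x, y, w, h = tlwh
        history = center_traj[cls_id].setdefault(int(tid), [])
        history.append((int(x + w / 2.), int(y + h / 2.)))
        center_traj[cls_id][int(tid)] = history[-max_len:]


record_centers(center_traj, 0, np.array([[10., 20., 30., 60.]]), np.array([1]))
print(center_traj)  # [{1: [(25, 50)]}]
```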
Note that only support one-class" + "counting, multi-class counting is coming soon.") + parser.add_argument( + "--secs_interval", + type=int, + default=2, + help="The seconds interval to count after tracking") + parser.add_argument( + "--draw_center_traj", + action='store_true', + help="Whether drawing the trajectory of center") parser.add_argument( "--mtmct_dir", type=str,