From 4e51747527f996466b9510bd2f9a88dbc5350bda Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Wed, 17 Nov 2021 18:59:37 +0800 Subject: [PATCH] support draw traj in mct (#4604) --- deploy/pptracking/python/mot_jde_infer.py | 2 +- deploy/pptracking/python/utils.py | 2 +- deploy/pptracking/python/visualize.py | 55 +++++++++++++++-------- 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/deploy/pptracking/python/mot_jde_infer.py b/deploy/pptracking/python/mot_jde_infer.py index 857168dba..76091f794 100644 --- a/deploy/pptracking/python/mot_jde_infer.py +++ b/deploy/pptracking/python/mot_jde_infer.py @@ -224,7 +224,7 @@ def predict_video(detector, camera_id): entrance = None records = None if FLAGS.draw_center_traj: - center_traj = {} + center_traj = [{} for i in range(num_classes)] if num_classes == 1: id_set = set() diff --git a/deploy/pptracking/python/utils.py b/deploy/pptracking/python/utils.py index dfa8e687f..fe5000d2d 100644 --- a/deploy/pptracking/python/utils.py +++ b/deploy/pptracking/python/utils.py @@ -129,7 +129,7 @@ def argsparser(): parser.add_argument( "--secs_interval", type=int, - default=10, + default=2, help="The seconds interval to count after tracking") parser.add_argument( "--draw_center_traj", diff --git a/deploy/pptracking/python/visualize.py b/deploy/pptracking/python/visualize.py index 22db57a94..f419581f4 100644 --- a/deploy/pptracking/python/visualize.py +++ b/deploy/pptracking/python/visualize.py @@ -213,6 +213,33 @@ def plot_tracking_dict(image, radius = max(5, int(im_w / 140.)) + if num_classes == 1: + start = records[-1].find('Total') + end = records[-1].find('In') + cv2.putText( + im, + records[-1][start:end - 2], (0, int(40 * text_scale)), + cv2.FONT_HERSHEY_PLAIN, + text_scale, (0, 0, 255), + thickness=2) + + if num_classes == 1 and do_entrance_counting: + entrance_line = tuple(map(int, entrance)) + cv2.rectangle( + im, + entrance_line[0:2], + entrance_line[2:4], + color=(0, 255, 255), + thickness=line_thickness) + # find start location for entrance counting data + start = records[-1].find('In') + cv2.putText( + im, + records[-1][start:-1], (0, int(60 * text_scale)), + cv2.FONT_HERSHEY_PLAIN, + text_scale, (0, 0, 255), + thickness=2) + for cls_id in range(num_classes): tlwhs = tlwhs_dict[cls_id] obj_ids = obj_ids_dict[cls_id] @@ -233,9 +260,9 @@ def plot_tracking_dict(image, obj_id = int(obj_ids[i]) if center_traj is not None: record_id.add(obj_id) - if obj_id not in center_traj: - center_traj[obj_id] = deque(maxlen=30) - center_traj[obj_id].append(center) + if obj_id not in center_traj[cls_id]: + center_traj[cls_id][obj_id] = deque(maxlen=30) + center_traj[cls_id][obj_id].append(center) id_text = '{}'.format(int(obj_id)) if ids2names != []: @@ -266,19 +293,11 @@ def plot_tracking_dict(image, cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 255, 255), thickness=text_thickness) - if num_classes == 1 and do_entrance_counting: - entrance_line = tuple(map(int, entrance)) - cv2.rectangle( - im, - entrance_line[0:2], - entrance_line[2:4], - color=(0, 255, 255), - thickness=line_thickness) - - if center_traj is not None: - for i in center_traj.keys(): - if i not in record_id: - continue - for point in center_traj[i]: - cv2.circle(im, point, 3, (0, 0, 255), -1) + if center_traj is not None: + for traj in center_traj: + for i in traj.keys(): + if i not in record_id: + continue + for point in traj[i]: + cv2.circle(im, point, 3, (0, 0, 255), -1) return im -- GitLab