diff --git a/deploy/pptracking/python/mot_jde_infer.py b/deploy/pptracking/python/mot_jde_infer.py index 857168dba167ea7f480400ea8adc9face0289dac..76091f7940fde9d366f7a57dd72c1f67e3bc9eca 100644 --- a/deploy/pptracking/python/mot_jde_infer.py +++ b/deploy/pptracking/python/mot_jde_infer.py @@ -224,7 +224,7 @@ def predict_video(detector, camera_id): entrance = None records = None if FLAGS.draw_center_traj: - center_traj = {} + center_traj = [{} for i in range(num_classes)] if num_classes == 1: id_set = set() diff --git a/deploy/pptracking/python/utils.py b/deploy/pptracking/python/utils.py index dfa8e687f61cdbe95f6c6b408d1ca320908080a5..fe5000d2df82faa602051da5462597a90f85c377 100644 --- a/deploy/pptracking/python/utils.py +++ b/deploy/pptracking/python/utils.py @@ -129,7 +129,7 @@ def argsparser(): parser.add_argument( "--secs_interval", type=int, - default=10, + default=2, help="The seconds interval to count after tracking") parser.add_argument( "--draw_center_traj", diff --git a/deploy/pptracking/python/visualize.py b/deploy/pptracking/python/visualize.py index 22db57a941b9176efae78a776d163a8fc3e9c207..f419581f47dd68a0d2a88072b46630fbac3066e1 100644 --- a/deploy/pptracking/python/visualize.py +++ b/deploy/pptracking/python/visualize.py @@ -213,6 +213,33 @@ def plot_tracking_dict(image, radius = max(5, int(im_w / 140.)) + if num_classes == 1: + start = records[-1].find('Total') + end = records[-1].find('In') + cv2.putText( + im, + records[-1][start:end - 2], (0, int(40 * text_scale)), + cv2.FONT_HERSHEY_PLAIN, + text_scale, (0, 0, 255), + thickness=2) + + if num_classes == 1 and do_entrance_counting: + entrance_line = tuple(map(int, entrance)) + cv2.rectangle( + im, + entrance_line[0:2], + entrance_line[2:4], + color=(0, 255, 255), + thickness=line_thickness) + # find start location for entrance counting data + start = records[-1].find('In') + cv2.putText( + im, + records[-1][start:-1], (0, int(60 * text_scale)), + cv2.FONT_HERSHEY_PLAIN, + text_scale, (0, 0, 255), + thickness=2) + for cls_id in range(num_classes): tlwhs = tlwhs_dict[cls_id] obj_ids = obj_ids_dict[cls_id] @@ -233,9 +260,9 @@ def plot_tracking_dict(image, obj_id = int(obj_ids[i]) if center_traj is not None: record_id.add(obj_id) - if obj_id not in center_traj: - center_traj[obj_id] = deque(maxlen=30) - center_traj[obj_id].append(center) + if obj_id not in center_traj[cls_id]: + center_traj[cls_id][obj_id] = deque(maxlen=30) + center_traj[cls_id][obj_id].append(center) id_text = '{}'.format(int(obj_id)) if ids2names != []: @@ -266,19 +293,11 @@ def plot_tracking_dict(image, cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 255, 255), thickness=text_thickness) - if num_classes == 1 and do_entrance_counting: - entrance_line = tuple(map(int, entrance)) - cv2.rectangle( - im, - entrance_line[0:2], - entrance_line[2:4], - color=(0, 255, 255), - thickness=line_thickness) - - if center_traj is not None: - for i in center_traj.keys(): - if i not in record_id: - continue - for point in center_traj[i]: - cv2.circle(im, point, 3, (0, 0, 255), -1) + if center_traj is not None: + for traj in center_traj: + for i in traj.keys(): + if i not in record_id: + continue + for point in traj[i]: + cv2.circle(im, point, 3, (0, 0, 255), -1) return im