diff --git a/deploy/pptracking/python/mot_jde_infer.py b/deploy/pptracking/python/mot_jde_infer.py index 41011f5f291f0524f0a521e0f8297d90421381e5..857168dba167ea7f480400ea8adc9face0289dac 100644 --- a/deploy/pptracking/python/mot_jde_infer.py +++ b/deploy/pptracking/python/mot_jde_infer.py @@ -220,6 +220,11 @@ def predict_video(detector, camera_id): num_classes = detector.num_classes data_type = 'mcmot' if num_classes > 1 else 'mot' ids2names = detector.pred_config.labels + center_traj = None + entrance = None + records = None + if FLAGS.draw_center_traj: + center_traj = {} if num_classes == 1: id_set = set() @@ -231,6 +236,7 @@ def predict_video(detector, camera_id): entrance = [0, height / 2., width, height / 2.] video_fps = fps + while (1): ret, frame = capture.read() if not ret: @@ -260,10 +266,9 @@ def predict_video(detector, camera_id): prev_center = statistic['prev_center'] records = statistic['records'] - elif num_classes > 1 and do_entrance_counting: + elif num_classes > 1 and FLAGS.do_entrance_counting: raise NotImplementedError( 'Multi-class flow counting is not implemented now!') - im = plot_tracking_dict( frame, num_classes, @@ -274,7 +279,9 @@ def predict_video(detector, camera_id): fps=fps, ids2names=ids2names, do_entrance_counting=FLAGS.do_entrance_counting, - entrance=entrance) + entrance=entrance, + records=records, + center_traj=center_traj) if FLAGS.save_images: save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2]) diff --git a/deploy/pptracking/python/utils.py b/deploy/pptracking/python/utils.py index b3f2852e90f23d4ec8241c140f63c66e44cdabde..dfa8e687f61cdbe95f6c6b408d1ca320908080a5 100644 --- a/deploy/pptracking/python/utils.py +++ b/deploy/pptracking/python/utils.py @@ -131,6 +131,10 @@ def argsparser(): type=int, default=10, help="The seconds interval to count after tracking") + parser.add_argument( + "--draw_center_traj", + action='store_true', + help="Whether drawing the trajectory of center") return parser diff --git a/deploy/pptracking/python/visualize.py b/deploy/pptracking/python/visualize.py index 86d9fc5991391fcced5e0d5530a9d886e52ac86c..22db57a941b9176efae78a776d163a8fc3e9c207 100644 --- a/deploy/pptracking/python/visualize.py +++ b/deploy/pptracking/python/visualize.py @@ -20,6 +20,7 @@ import cv2 import numpy as np from PIL import Image, ImageDraw import math +from collections import deque def visualize_box_mask(im, results, labels, threshold=0.5): @@ -198,7 +199,9 @@ def plot_tracking_dict(image, fps=0., ids2names=[], do_entrance_counting=False, - entrance=None): + entrance=None, + records=None, + center_traj=None): im = np.ascontiguousarray(np.copy(image)) im_h, im_w = im.shape[:2] @@ -222,10 +225,17 @@ def plot_tracking_dict(image, text_scale, (0, 0, 255), thickness=2) + record_id = set() for i, tlwh in enumerate(tlwhs): x1, y1, w, h = tlwh intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h))) + center = tuple(map(int, (x1 + w / 2., y1 + h / 2.))) obj_id = int(obj_ids[i]) + if center_traj is not None: + record_id.add(obj_id) + if obj_id not in center_traj: + center_traj[obj_id] = deque(maxlen=30) + center_traj[obj_id].append(center) id_text = '{}'.format(int(obj_id)) if ids2names != []: @@ -264,4 +274,11 @@ def plot_tracking_dict(image, entrance_line[2:4], color=(0, 255, 255), thickness=line_thickness) + + if center_traj is not None: + for i in center_traj.keys(): + if i not in record_id: + continue + for point in center_traj[i]: + cv2.circle(im, point, 3, (0, 0, 255), -1) return im