未验证 提交 fce66be9 编写于 作者: W wangguanzhong 提交者: GitHub

add draw_center_traj (#4567)

* add draw_center_traj
上级 7826f247
...@@ -220,6 +220,11 @@ def predict_video(detector, camera_id): ...@@ -220,6 +220,11 @@ def predict_video(detector, camera_id):
num_classes = detector.num_classes num_classes = detector.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot' data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = detector.pred_config.labels ids2names = detector.pred_config.labels
center_traj = None
entrance = None
records = None
if FLAGS.draw_center_traj:
center_traj = {}
if num_classes == 1: if num_classes == 1:
id_set = set() id_set = set()
...@@ -231,6 +236,7 @@ def predict_video(detector, camera_id): ...@@ -231,6 +236,7 @@ def predict_video(detector, camera_id):
entrance = [0, height / 2., width, height / 2.] entrance = [0, height / 2., width, height / 2.]
video_fps = fps video_fps = fps
while (1): while (1):
ret, frame = capture.read() ret, frame = capture.read()
if not ret: if not ret:
...@@ -260,10 +266,9 @@ def predict_video(detector, camera_id): ...@@ -260,10 +266,9 @@ def predict_video(detector, camera_id):
prev_center = statistic['prev_center'] prev_center = statistic['prev_center']
records = statistic['records'] records = statistic['records']
elif num_classes > 1 and do_entrance_counting: elif num_classes > 1 and FLAGS.do_entrance_counting:
raise NotImplementedError( raise NotImplementedError(
'Multi-class flow counting is not implemented now!') 'Multi-class flow counting is not implemented now!')
im = plot_tracking_dict( im = plot_tracking_dict(
frame, frame,
num_classes, num_classes,
...@@ -274,7 +279,9 @@ def predict_video(detector, camera_id): ...@@ -274,7 +279,9 @@ def predict_video(detector, camera_id):
fps=fps, fps=fps,
ids2names=ids2names, ids2names=ids2names,
do_entrance_counting=FLAGS.do_entrance_counting, do_entrance_counting=FLAGS.do_entrance_counting,
entrance=entrance) entrance=entrance,
records=records,
center_traj=center_traj)
if FLAGS.save_images: if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2]) save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
......
...@@ -131,6 +131,10 @@ def argsparser(): ...@@ -131,6 +131,10 @@ def argsparser():
type=int, type=int,
default=10, default=10,
help="The seconds interval to count after tracking") help="The seconds interval to count after tracking")
parser.add_argument(
"--draw_center_traj",
action='store_true',
help="Whether drawing the trajectory of center")
return parser return parser
......
...@@ -20,6 +20,7 @@ import cv2 ...@@ -20,6 +20,7 @@ import cv2
import numpy as np import numpy as np
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
import math import math
from collections import deque
def visualize_box_mask(im, results, labels, threshold=0.5): def visualize_box_mask(im, results, labels, threshold=0.5):
...@@ -198,7 +199,9 @@ def plot_tracking_dict(image, ...@@ -198,7 +199,9 @@ def plot_tracking_dict(image,
fps=0., fps=0.,
ids2names=[], ids2names=[],
do_entrance_counting=False, do_entrance_counting=False,
entrance=None): entrance=None,
records=None,
center_traj=None):
im = np.ascontiguousarray(np.copy(image)) im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2] im_h, im_w = im.shape[:2]
...@@ -222,10 +225,17 @@ def plot_tracking_dict(image, ...@@ -222,10 +225,17 @@ def plot_tracking_dict(image,
text_scale, (0, 0, 255), text_scale, (0, 0, 255),
thickness=2) thickness=2)
record_id = set()
for i, tlwh in enumerate(tlwhs): for i, tlwh in enumerate(tlwhs):
x1, y1, w, h = tlwh x1, y1, w, h = tlwh
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h))) intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
center = tuple(map(int, (x1 + w / 2., y1 + h / 2.)))
obj_id = int(obj_ids[i]) obj_id = int(obj_ids[i])
if center_traj is not None:
record_id.add(obj_id)
if obj_id not in center_traj:
center_traj[obj_id] = deque(maxlen=30)
center_traj[obj_id].append(center)
id_text = '{}'.format(int(obj_id)) id_text = '{}'.format(int(obj_id))
if ids2names != []: if ids2names != []:
...@@ -264,4 +274,11 @@ def plot_tracking_dict(image, ...@@ -264,4 +274,11 @@ def plot_tracking_dict(image,
entrance_line[2:4], entrance_line[2:4],
color=(0, 255, 255), color=(0, 255, 255),
thickness=line_thickness) thickness=line_thickness)
if center_traj is not None:
for i in center_traj.keys():
if i not in record_id:
continue
for point in center_traj[i]:
cv2.circle(im, point, 3, (0, 0, 255), -1)
return im return im
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册