未验证 提交 fce66be9 编写于 作者: W wangguanzhong 提交者: GitHub

add draw_center_traj (#4567)

* add draw_center_traj
上级 7826f247
......@@ -220,6 +220,11 @@ def predict_video(detector, camera_id):
num_classes = detector.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = detector.pred_config.labels
center_traj = None
entrance = None
records = None
if FLAGS.draw_center_traj:
center_traj = {}
if num_classes == 1:
id_set = set()
......@@ -231,6 +236,7 @@ def predict_video(detector, camera_id):
entrance = [0, height / 2., width, height / 2.]
video_fps = fps
while (1):
ret, frame = capture.read()
if not ret:
......@@ -260,10 +266,9 @@ def predict_video(detector, camera_id):
prev_center = statistic['prev_center']
records = statistic['records']
elif num_classes > 1 and do_entrance_counting:
elif num_classes > 1 and FLAGS.do_entrance_counting:
raise NotImplementedError(
'Multi-class flow counting is not implemented now!')
im = plot_tracking_dict(
frame,
num_classes,
......@@ -274,7 +279,9 @@ def predict_video(detector, camera_id):
fps=fps,
ids2names=ids2names,
do_entrance_counting=FLAGS.do_entrance_counting,
entrance=entrance)
entrance=entrance,
records=records,
center_traj=center_traj)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
......
......@@ -131,6 +131,10 @@ def argsparser():
type=int,
default=10,
help="The seconds interval to count after tracking")
parser.add_argument(
"--draw_center_traj",
action='store_true',
help="Whether drawing the trajectory of center")
return parser
......
......@@ -20,6 +20,7 @@ import cv2
import numpy as np
from PIL import Image, ImageDraw
import math
from collections import deque
def visualize_box_mask(im, results, labels, threshold=0.5):
......@@ -198,7 +199,9 @@ def plot_tracking_dict(image,
fps=0.,
ids2names=[],
do_entrance_counting=False,
entrance=None):
entrance=None,
records=None,
center_traj=None):
im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2]
......@@ -222,10 +225,17 @@ def plot_tracking_dict(image,
text_scale, (0, 0, 255),
thickness=2)
record_id = set()
for i, tlwh in enumerate(tlwhs):
x1, y1, w, h = tlwh
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
center = tuple(map(int, (x1 + w / 2., y1 + h / 2.)))
obj_id = int(obj_ids[i])
if center_traj is not None:
record_id.add(obj_id)
if obj_id not in center_traj:
center_traj[obj_id] = deque(maxlen=30)
center_traj[obj_id].append(center)
id_text = '{}'.format(int(obj_id))
if ids2names != []:
......@@ -264,4 +274,11 @@ def plot_tracking_dict(image,
entrance_line[2:4],
color=(0, 255, 255),
thickness=line_thickness)
if center_traj is not None:
for i in center_traj.keys():
if i not in record_id:
continue
for point in center_traj[i]:
cv2.circle(im, point, 3, (0, 0, 255), -1)
return im
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册