未验证 提交 4e517475 编写于 作者: W wangguanzhong 提交者: GitHub

support draw traj in mct (#4604)

上级 d6dff40e
......@@ -224,7 +224,7 @@ def predict_video(detector, camera_id):
entrance = None
records = None
if FLAGS.draw_center_traj:
center_traj = {}
center_traj = [{} for i in range(num_classes)]
if num_classes == 1:
id_set = set()
......
......@@ -129,7 +129,7 @@ def argsparser():
parser.add_argument(
"--secs_interval",
type=int,
default=10,
default=2,
help="The seconds interval to count after tracking")
parser.add_argument(
"--draw_center_traj",
......
......@@ -213,6 +213,33 @@ def plot_tracking_dict(image,
radius = max(5, int(im_w / 140.))
if num_classes == 1:
start = records[-1].find('Total')
end = records[-1].find('In')
cv2.putText(
im,
records[-1][start:end - 2], (0, int(40 * text_scale)),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=2)
if num_classes == 1 and do_entrance_counting:
entrance_line = tuple(map(int, entrance))
cv2.rectangle(
im,
entrance_line[0:2],
entrance_line[2:4],
color=(0, 255, 255),
thickness=line_thickness)
# find start location for entrance counting data
start = records[-1].find('In')
cv2.putText(
im,
records[-1][start:-1], (0, int(60 * text_scale)),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=2)
for cls_id in range(num_classes):
tlwhs = tlwhs_dict[cls_id]
obj_ids = obj_ids_dict[cls_id]
......@@ -233,9 +260,9 @@ def plot_tracking_dict(image,
obj_id = int(obj_ids[i])
if center_traj is not None:
record_id.add(obj_id)
if obj_id not in center_traj:
center_traj[obj_id] = deque(maxlen=30)
center_traj[obj_id].append(center)
if obj_id not in center_traj[cls_id]:
center_traj[cls_id][obj_id] = deque(maxlen=30)
center_traj[cls_id][obj_id].append(center)
id_text = '{}'.format(int(obj_id))
if ids2names != []:
......@@ -266,19 +293,11 @@ def plot_tracking_dict(image,
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=text_thickness)
if num_classes == 1 and do_entrance_counting:
entrance_line = tuple(map(int, entrance))
cv2.rectangle(
im,
entrance_line[0:2],
entrance_line[2:4],
color=(0, 255, 255),
thickness=line_thickness)
if center_traj is not None:
for i in center_traj.keys():
if i not in record_id:
continue
for point in center_traj[i]:
cv2.circle(im, point, 3, (0, 0, 255), -1)
if center_traj is not None:
for traj in center_traj:
for i in traj.keys():
if i not in record_id:
continue
for point in traj[i]:
cv2.circle(im, point, 3, (0, 0, 255), -1)
return im
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册