未验证 提交 e4259d42 编写于 作者: G George Ni 提交者: GitHub

fix mot infer video (#3823)

上级 a3e2b2ea
...@@ -221,7 +221,8 @@ def predict_video(detector, camera_id): ...@@ -221,7 +221,8 @@ def predict_video(detector, camera_id):
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name) out_path = os.path.join(FLAGS.output_dir, video_name)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) if not FLAGS.save_images:
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0 frame_id = 0
timer = MOTTimer() timer = MOTTimer()
results = [] results = []
...@@ -236,7 +237,7 @@ def predict_video(detector, camera_id): ...@@ -236,7 +237,7 @@ def predict_video(detector, camera_id):
results.append((frame_id + 1, online_tlwhs, online_scores, online_ids)) results.append((frame_id + 1, online_tlwhs, online_scores, online_ids))
fps = 1. / timer.average_time fps = 1. / timer.average_time
online_im = mot_vis.plot_tracking( im = mot_vis.plot_tracking(
frame, frame,
online_tlwhs, online_tlwhs,
online_ids, online_ids,
...@@ -249,11 +250,11 @@ def predict_video(detector, camera_id): ...@@ -249,11 +250,11 @@ def predict_video(detector, camera_id):
os.makedirs(save_dir) os.makedirs(save_dir)
cv2.imwrite( cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
online_im) im)
else:
writer.write(im)
frame_id += 1 frame_id += 1
print('detect frame:%d' % (frame_id)) print('detect frame:%d' % (frame_id))
im = np.array(online_im)
writer.write(im)
if camera_id != -1: if camera_id != -1:
cv2.imshow('Tracking Detection', im) cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
...@@ -262,7 +263,15 @@ def predict_video(detector, camera_id): ...@@ -262,7 +263,15 @@ def predict_video(detector, camera_id):
result_filename = os.path.join(FLAGS.output_dir, result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt') video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results) write_mot_results(result_filename, results)
writer.release()
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
save_dir, out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release()
def main(): def main():
......
...@@ -137,7 +137,8 @@ def mot_keypoint_unite_predict_video(mot_model, ...@@ -137,7 +137,8 @@ def mot_keypoint_unite_predict_video(mot_model,
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name) out_path = os.path.join(FLAGS.output_dir, video_name)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) if not FLAGS.save_images:
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0 frame_id = 0
timer_mot = FPSTimer() timer_mot = FPSTimer()
timer_kp = FPSTimer() timer_kp = FPSTimer()
...@@ -202,8 +203,8 @@ def mot_keypoint_unite_predict_video(mot_model, ...@@ -202,8 +203,8 @@ def mot_keypoint_unite_predict_video(mot_model,
os.makedirs(save_dir) os.makedirs(save_dir)
cv2.imwrite( cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im) os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
else:
writer.write(im) writer.write(im)
if camera_id != -1: if camera_id != -1:
cv2.imshow('Tracking and keypoint results', im) cv2.imshow('Tracking and keypoint results', im)
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
...@@ -212,7 +213,15 @@ def mot_keypoint_unite_predict_video(mot_model, ...@@ -212,7 +213,15 @@ def mot_keypoint_unite_predict_video(mot_model,
result_filename = os.path.join(FLAGS.output_dir, result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt') video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, mot_results) write_mot_results(result_filename, mot_results)
writer.release()
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
save_dir, out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release()
def main(): def main():
......
...@@ -356,7 +356,8 @@ def predict_video(detector, reid_model, camera_id): ...@@ -356,7 +356,8 @@ def predict_video(detector, reid_model, camera_id):
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name) out_path = os.path.join(FLAGS.output_dir, video_name)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) if not FLAGS.save_images:
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0 frame_id = 0
timer = MOTTimer() timer = MOTTimer()
results = [] results = []
...@@ -379,7 +380,7 @@ def predict_video(detector, reid_model, camera_id): ...@@ -379,7 +380,7 @@ def predict_video(detector, reid_model, camera_id):
results.append((frame_id + 1, online_tlwhs, online_scores, online_ids)) results.append((frame_id + 1, online_tlwhs, online_scores, online_ids))
fps = 1. / timer.average_time fps = 1. / timer.average_time
online_im = mot_vis.plot_tracking( im = mot_vis.plot_tracking(
frame, frame,
online_tlwhs, online_tlwhs,
online_ids, online_ids,
...@@ -392,11 +393,11 @@ def predict_video(detector, reid_model, camera_id): ...@@ -392,11 +393,11 @@ def predict_video(detector, reid_model, camera_id):
os.makedirs(save_dir) os.makedirs(save_dir)
cv2.imwrite( cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
online_im) im)
else:
writer.write(im)
frame_id += 1 frame_id += 1
print('detect frame:%d' % (frame_id)) print('detect frame:%d' % (frame_id))
im = np.array(online_im)
writer.write(im)
if camera_id != -1: if camera_id != -1:
cv2.imshow('Tracking Detection', im) cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
...@@ -405,7 +406,15 @@ def predict_video(detector, reid_model, camera_id): ...@@ -405,7 +406,15 @@ def predict_video(detector, reid_model, camera_id):
result_filename = os.path.join(FLAGS.output_dir, result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt') video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results) write_mot_results(result_filename, results)
writer.release()
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
save_dir, out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release()
def main(): def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册