未验证 提交 e6216d50 编写于 作者: L lazyn1997 提交者: GitHub

fix video infer(#5107) (#5931)

* fix video infer(#5107)

* fix video infer(others)
上级 dba0d225
...@@ -329,7 +329,7 @@ class Detector(object): ...@@ -329,7 +329,7 @@ class Detector(object):
break break
print('detect frame: %d' % (index)) print('detect frame: %d' % (index))
index += 1 index += 1
results = self.predict_image([frame], visual=False) results = self.predict_image([frame[:, :, ::-1]], visual=False)
im = visualize_box_mask( im = visualize_box_mask(
frame, frame,
......
...@@ -266,7 +266,7 @@ class KeyPointDetector(Detector): ...@@ -266,7 +266,7 @@ class KeyPointDetector(Detector):
break break
print('detect frame: %d' % (index)) print('detect frame: %d' % (index))
index += 1 index += 1
results = self.predict_image([frame], visual=False) results = self.predict_image([frame[:, :, ::-1]], visual=False)
im_results = {} im_results = {}
im_results['keypoint'] = [results['keypoint'], results['score']] im_results['keypoint'] = [results['keypoint'], results['score']]
im = visualize_pose( im = visualize_pose(
......
...@@ -296,7 +296,7 @@ class JDE_Detector(Detector): ...@@ -296,7 +296,7 @@ class JDE_Detector(Detector):
timer.tic() timer.tic()
seq_name = video_out_name.split('.')[0] seq_name = video_out_name.split('.')[0]
mot_results = self.predict_image( mot_results = self.predict_image(
[frame], visual=False, seq_name=seq_name) [frame[:, :, ::-1]], visual=False, seq_name=seq_name)
timer.toc() timer.toc()
online_tlwhs, online_scores, online_ids = mot_results[0] online_tlwhs, online_scores, online_ids = mot_results[0]
......
...@@ -167,7 +167,10 @@ def mot_topdown_unite_predict_video(mot_detector, ...@@ -167,7 +167,10 @@ def mot_topdown_unite_predict_video(mot_detector,
# mot model # mot model
timer_mot.tic() timer_mot.tic()
mot_results = mot_detector.predict_image([frame], visual=False)
frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
mot_results = mot_detector.predict_image([frame2], visual=False)
timer_mot.toc() timer_mot.toc()
online_tlwhs, online_scores, online_ids = mot_results[0] online_tlwhs, online_scores, online_ids = mot_results[0]
results = convert_mot_to_det( results = convert_mot_to_det(
...@@ -179,7 +182,7 @@ def mot_topdown_unite_predict_video(mot_detector, ...@@ -179,7 +182,7 @@ def mot_topdown_unite_predict_video(mot_detector,
# keypoint model # keypoint model
timer_kp.tic() timer_kp.tic()
keypoint_res = predict_with_given_det( keypoint_res = predict_with_given_det(
frame, results, topdown_keypoint_detector, keypoint_batch_size, frame2, results, topdown_keypoint_detector, keypoint_batch_size,
FLAGS.run_benchmark) FLAGS.run_benchmark)
timer_kp.toc() timer_kp.toc()
timer_mot_kp.toc() timer_mot_kp.toc()
......
...@@ -414,7 +414,7 @@ class SDE_Detector(Detector): ...@@ -414,7 +414,7 @@ class SDE_Detector(Detector):
timer.tic() timer.tic()
seq_name = video_out_name.split('.')[0] seq_name = video_out_name.split('.')[0]
mot_results = self.predict_image( mot_results = self.predict_image(
[frame], visual=False, seq_name=seq_name) [frame[:, :, ::-1]], visual=False, seq_name=seq_name)
timer.toc() timer.toc()
# bs=1 in MOT model # bs=1 in MOT model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册