未验证 提交 8e13eedb 编写于 作者: W wangguanzhong 提交者: GitHub

[MOT] update window timer for pptracking (#4527)

* update window timer for pptracking

* keep origin mp4v format
上级 94c19353
...@@ -17,6 +17,7 @@ import cv2 ...@@ -17,6 +17,7 @@ import cv2
import time import time
import paddle import paddle
import numpy as np import numpy as np
import collections
__all__ = [ __all__ = [
'MOTTimer', 'Detection', 'write_mot_results', 'load_det_results', 'MOTTimer', 'Detection', 'write_mot_results', 'load_det_results',
...@@ -29,13 +30,11 @@ class MOTTimer(object): ...@@ -29,13 +30,11 @@ class MOTTimer(object):
This class used to compute and print the current FPS while evaling. This class used to compute and print the current FPS while evaling.
""" """
def __init__(self): def __init__(self, window_size=20):
self.total_time = 0.
self.calls = 0
self.start_time = 0. self.start_time = 0.
self.diff = 0. self.diff = 0.
self.average_time = 0.
self.duration = 0. self.duration = 0.
self.deque = collections.deque(maxlen=window_size)
def tic(self): def tic(self):
# using time.time instead of time.clock because time time.clock # using time.time instead of time.clock because time time.clock
...@@ -44,21 +43,16 @@ class MOTTimer(object): ...@@ -44,21 +43,16 @@ class MOTTimer(object):
def toc(self, average=True): def toc(self, average=True):
self.diff = time.time() - self.start_time self.diff = time.time() - self.start_time
self.total_time += self.diff self.deque.append(self.diff)
self.calls += 1
self.average_time = self.total_time / self.calls
if average: if average:
self.duration = self.average_time self.duration = np.mean(self.deque)
else: else:
self.duration = self.diff self.duration = np.sum(self.deque)
return self.duration return self.duration
def clear(self): def clear(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0. self.start_time = 0.
self.diff = 0. self.diff = 0.
self.average_time = 0.
self.duration = 0. self.duration = 0.
......
...@@ -211,7 +211,8 @@ def predict_video(detector, camera_id): ...@@ -211,7 +211,8 @@ def predict_video(detector, camera_id):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name) out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images: if not FLAGS.save_images:
fourcc = cv2.VideoWriter_fourcc(*'mp4v') video_format = 'mp4v'
fourcc = cv2.VideoWriter_fourcc(*video_format)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0 frame_id = 0
timer = MOTTimer() timer = MOTTimer()
...@@ -243,6 +244,7 @@ def predict_video(detector, camera_id): ...@@ -243,6 +244,7 @@ def predict_video(detector, camera_id):
results[cls_id].append((frame_id + 1, online_tlwhs[cls_id], results[cls_id].append((frame_id + 1, online_tlwhs[cls_id],
online_scores[cls_id], online_ids[cls_id])) online_scores[cls_id], online_ids[cls_id]))
fps = 1. / timer.duration
# NOTE: just implement flow statistic for one class # NOTE: just implement flow statistic for one class
if num_classes == 1: if num_classes == 1:
result = (frame_id + 1, online_tlwhs[0], online_scores[0], result = (frame_id + 1, online_tlwhs[0], online_scores[0],
...@@ -262,7 +264,6 @@ def predict_video(detector, camera_id): ...@@ -262,7 +264,6 @@ def predict_video(detector, camera_id):
raise NotImplementedError( raise NotImplementedError(
'Multi-class flow counting is not implemented now!') 'Multi-class flow counting is not implemented now!')
fps = 1. / timer.average_time
im = plot_tracking_dict( im = plot_tracking_dict(
frame, frame,
num_classes, num_classes,
...@@ -282,7 +283,7 @@ def predict_video(detector, camera_id): ...@@ -282,7 +283,7 @@ def predict_video(detector, camera_id):
writer.write(im) writer.write(im)
frame_id += 1 frame_id += 1
print('detect frame: %d' % (frame_id)) print('detect frame: %d, fps: %f' % (frame_id, fps))
if camera_id != -1: if camera_id != -1:
cv2.imshow('Tracking Detection', im) cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
......
...@@ -539,7 +539,8 @@ def predict_video(detector, reid_model, camera_id): ...@@ -539,7 +539,8 @@ def predict_video(detector, reid_model, camera_id):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name) out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images: if not FLAGS.save_images:
fourcc = cv2.VideoWriter_fourcc(*'mp4v') video_format = 'mp4v'
fourcc = cv2.VideoWriter_fourcc(*video_format)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0 frame_id = 0
timer = MOTTimer() timer = MOTTimer()
...@@ -566,7 +567,7 @@ def predict_video(detector, reid_model, camera_id): ...@@ -566,7 +567,7 @@ def predict_video(detector, reid_model, camera_id):
(frame_id + 1, online_tlwhs, online_scores, online_ids)) (frame_id + 1, online_tlwhs, online_scores, online_ids))
timer.toc() timer.toc()
fps = 1. / timer.average_time fps = 1. / timer.duration
im = plot_tracking( im = plot_tracking(
frame, frame,
online_tlwhs, online_tlwhs,
...@@ -585,7 +586,7 @@ def predict_video(detector, reid_model, camera_id): ...@@ -585,7 +586,7 @@ def predict_video(detector, reid_model, camera_id):
writer.write(im) writer.write(im)
frame_id += 1 frame_id += 1
print('detect frame:%d' % (frame_id)) print('detect frame:%d, fps: %f' % (frame_id, fps))
if camera_id != -1: if camera_id != -1:
cv2.imshow('Tracking Detection', im) cv2.imshow('Tracking Detection', im)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册