未验证 提交 8e13eedb 编写于 作者: W wangguanzhong 提交者: GitHub

[MOT] update window timer for pptracking (#4527)

* update window timer for pptracking

* keep origin mp4v format
上级 94c19353
......@@ -17,6 +17,7 @@ import cv2
import time
import paddle
import numpy as np
import collections
__all__ = [
'MOTTimer', 'Detection', 'write_mot_results', 'load_det_results',
......@@ -29,13 +30,11 @@ class MOTTimer(object):
This class used to compute and print the current FPS while evaling.
"""
def __init__(self):
self.total_time = 0.
self.calls = 0
def __init__(self, window_size=20):
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
self.duration = 0.
self.deque = collections.deque(maxlen=window_size)
def tic(self):
# using time.time instead of time.clock because time time.clock
......@@ -44,21 +43,16 @@ class MOTTimer(object):
def toc(self, average=True):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls
self.deque.append(self.diff)
if average:
self.duration = self.average_time
self.duration = np.mean(self.deque)
else:
self.duration = self.diff
self.duration = np.sum(self.deque)
return self.duration
def clear(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
self.duration = 0.
......
......@@ -211,7 +211,8 @@ def predict_video(detector, camera_id):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_format = 'mp4v'
fourcc = cv2.VideoWriter_fourcc(*video_format)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer = MOTTimer()
......@@ -243,6 +244,7 @@ def predict_video(detector, camera_id):
results[cls_id].append((frame_id + 1, online_tlwhs[cls_id],
online_scores[cls_id], online_ids[cls_id]))
fps = 1. / timer.duration
# NOTE: just implement flow statistic for one class
if num_classes == 1:
result = (frame_id + 1, online_tlwhs[0], online_scores[0],
......@@ -262,7 +264,6 @@ def predict_video(detector, camera_id):
raise NotImplementedError(
'Multi-class flow counting is not implemented now!')
fps = 1. / timer.average_time
im = plot_tracking_dict(
frame,
num_classes,
......@@ -282,7 +283,7 @@ def predict_video(detector, camera_id):
writer.write(im)
frame_id += 1
print('detect frame: %d' % (frame_id))
print('detect frame: %d, fps: %f' % (frame_id, fps))
if camera_id != -1:
cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
......
......@@ -539,7 +539,8 @@ def predict_video(detector, reid_model, camera_id):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_format = 'mp4v'
fourcc = cv2.VideoWriter_fourcc(*video_format)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer = MOTTimer()
......@@ -566,7 +567,7 @@ def predict_video(detector, reid_model, camera_id):
(frame_id + 1, online_tlwhs, online_scores, online_ids))
timer.toc()
fps = 1. / timer.average_time
fps = 1. / timer.duration
im = plot_tracking(
frame,
online_tlwhs,
......@@ -585,7 +586,7 @@ def predict_video(detector, reid_model, camera_id):
writer.write(im)
frame_id += 1
print('detect frame:%d' % (frame_id))
print('detect frame:%d, fps: %f' % (frame_id, fps))
if camera_id != -1:
cv2.imshow('Tracking Detection', im)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册