未验证 提交 0b763b3d 编写于 作者: G George Ni 提交者: GitHub

[MOT] decord video infer (#3251)

* decord video infer

* add video infer write results
上级 51d311d3
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
import os import os
import cv2
import numpy as np import numpy as np
import decord as de
from collections import OrderedDict from collections import OrderedDict
try: try:
from collections.abc import Sequence from collections.abc import Sequence
...@@ -317,26 +317,23 @@ class MOTVideoDataset(DetDataset): ...@@ -317,26 +317,23 @@ class MOTVideoDataset(DetDataset):
self.dataset_dir = dataset_dir self.dataset_dir = dataset_dir
self.keep_ori_im = keep_ori_im self.keep_ori_im = keep_ori_im
self.roidbs = None self.roidbs = None
self.frame_rate = 25
def parse_dataset(self, ): def parse_dataset(self, ):
if not self.roidbs: if not self.roidbs:
self.roidbs = self._load_video_images() self.roidbs = self._load_video_images()
def _load_video_images(self): def _load_video_images(self):
self.cap = cv2.VideoCapture(self.video_file) self.video_frames = de.VideoReader(self.video_file)
self.vn = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) self.video_length = len(self.video_frames)
self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS))) logger.info('Length of the video: {:d} frames.'.format(
logger.info('Length of the video: {:d} frames'.format(self.vn)) self.video_length))
res = True
ct = 0
records = [] records = []
while res: for idx in range(self.video_length):
res, img = self.cap.read() image = self.video_frames.get_batch([idx]).asnumpy()[0]
image = np.ascontiguousarray(img, dtype=np.float32)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
im_shape = image.shape im_shape = image.shape
rec = { rec = {
'im_id': np.array([ct]), 'im_id': np.array([idx]),
'image': image, 'image': image,
'h': im_shape[0], 'h': im_shape[0],
'w': im_shape[1], 'w': im_shape[1],
...@@ -347,10 +344,8 @@ class MOTVideoDataset(DetDataset): ...@@ -347,10 +344,8 @@ class MOTVideoDataset(DetDataset):
} }
if self.keep_ori_im: if self.keep_ori_im:
rec.update({'ori_image': image}) rec.update({'ori_image': image})
ct += 1
records.append(rec) records.append(rec)
records = records[:-1] assert len(records) > 0, "No image file found."
assert len(records) > 0, "No image file found"
return records return records
def set_video(self, video_file): def set_video(self, video_file):
......
...@@ -372,6 +372,8 @@ class Tracker(object): ...@@ -372,6 +372,8 @@ class Tracker(object):
else: else:
raise ValueError(model_type) raise ValueError(model_type)
self.write_mot_results(result_filename, results, data_type)
if save_videos: if save_videos:
output_video_path = os.path.join(save_dir, '..', output_video_path = os.path.join(save_dir, '..',
'{}_vis.mp4'.format(seq)) '{}_vis.mp4'.format(seq))
......
...@@ -14,3 +14,4 @@ lap ...@@ -14,3 +14,4 @@ lap
sklearn sklearn
motmetrics motmetrics
openpyxl openpyxl
decord
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册