未验证 提交 0c2b74f9 编写于 作者: G George Ni 提交者: GitHub

[MOT] Fix mot video decode (#3681)

* mot video decode to images

* add infer_dir for mot

* fix export and unite pose infer

* remove task in mot source

* set image_dir

* update doc
上级 320c6eea
metric: MOT
num_classes: 1
MOTDataZoo: {
'MOT15_train': ['ADL-Rundle-6', 'ADL-Rundle-8', 'ETH-Bahnhof', 'ETH-Pedcross2', 'ETH-Sunnyday', 'KITTI-13', 'KITTI-17', 'PETS09-S2L1', 'TUD-Campus', 'TUD-Stadtmitte', 'Venice-2'],
'MOT15_test': ['ADL-Rundle-1', 'ADL-Rundle-3', 'AVG-TownCentre', 'ETH-Crossing', 'ETH-Jelmoli', 'ETH-Linthescher', 'KITTI-16', 'KITTI-19', 'PETS09-S2L2', 'TUD-Crossing', 'Venice-1'],
'MOT16_train': ['MOT16-02', 'MOT16-04', 'MOT16-05', 'MOT16-09', 'MOT16-10', 'MOT16-11', 'MOT16-13'],
'MOT16_test': ['MOT16-01', 'MOT16-03', 'MOT16-06', 'MOT16-07', 'MOT16-08', 'MOT16-12', 'MOT16-14'],
'MOT17_train': ['MOT17-02-SDP', 'MOT17-04-SDP', 'MOT17-05-SDP', 'MOT17-09-SDP', 'MOT17-10-SDP', 'MOT17-11-SDP', 'MOT17-13-SDP'],
'MOT17_test': ['MOT17-01-SDP', 'MOT17-03-SDP', 'MOT17-06-SDP', 'MOT17-07-SDP', 'MOT17-08-SDP', 'MOT17-12-SDP', 'MOT17-14-SDP'],
'MOT20_train': ['MOT20-01', 'MOT20-02', 'MOT20-03', 'MOT20-05'],
'MOT20_test': ['MOT20-04', 'MOT20-06', 'MOT20-07', 'MOT20-08'],
'demo': ['MOT16-02'],
}
# for MOT training
TrainDataset:
!MOTDataSet
......@@ -21,16 +9,15 @@ TrainDataset:
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
# for MOT evaluation
# If you want to change the MOT evaluation dataset, please modify 'task' and 'data_root'
# If you want to change the MOT evaluation dataset, please modify 'data_root'
EvalMOTDataset:
!MOTImageFolder
task: MOT16_train
dataset_dir: dataset/mot
data_root: MOT16/images/train
keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
# for MOT video inference
TestMOTDataset:
!MOTVideoDataset
!MOTImageFolder
dataset_dir: dataset/mot
keep_ori_im: True # set True if save visualization images or video
......@@ -224,11 +224,10 @@ CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/fairmot_d
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=output/fairmot_dla34_30e_1088x608/model_final.pdparams
```
**Notes:**
The default evaluation dataset is MOT-16 Train Set. If you want to change the evaluation dataset, please refer to the following code and modify `configs/datasets/mot.yml`
The default evaluation dataset is MOT-16 Train Set. If you want to change the evaluation dataset, please refer to the following code and modify `configs/datasets/mot.yml`, modify `data_root`
```
EvalMOTDataset:
!MOTImageFolder
task: MOT17_train
dataset_dir: dataset/mot
data_root: MOT17/images/train
keep_ori_im: False # set True if save visualization images or video
......@@ -242,6 +241,14 @@ Inference a vidoe on single GPU with following command:
# inference on video and save a video
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams --video_file={your video name}.mp4 --save_videos
```
Inference a image folder on single GPU with following command:
```bash
# inference image folder and save a video
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams --image_dir={your infer images folder} --save_videos
```
**Notes:**
Please make sure that [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first, on Linux(Ubuntu) platform you can directly install it by the following command:`apt-get update && apt-get install -y ffmpeg`.
......
......@@ -222,11 +222,10 @@ CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/fairmot_d
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=output/fairmot_dla34_30e_1088x608/model_final.pdparams
```
**注意:**
默认评估的是MOT-16 Train Set数据集, 如需换评估数据集可参照以下代码修改`configs/datasets/mot.yml`
默认评估的是MOT-16 Train Set数据集,如需换评估数据集可参照以下代码修改`configs/datasets/mot.yml`,修改`data_root`
```
EvalMOTDataset:
!MOTImageFolder
task: MOT17_train
dataset_dir: dataset/mot
data_root: MOT17/images/train
keep_ori_im: False # set True if save visualization images or video
......@@ -241,6 +240,13 @@ EvalMOTDataset:
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams --video_file={your video name}.mp4 --save_videos
```
使用单个GPU通过如下命令预测一个图片文件夹,并保存为视频
```bash
# 预测一个图片文件夹
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams --image_dir={your infer images folder} --save_videos
```
**注意:**
请先确保已经安装了[ffmpeg](https://ffmpeg.org/ffmpeg.html), Linux(Ubuntu)平台可以直接用以下命令安装:`apt-get update && apt-get install -y ffmpeg`
......
......@@ -11,6 +11,7 @@ TestMOTReader:
inputs_def:
image_shape: [3, 608, 1088]
sample_transforms:
- Decode: {}
- LetterBoxResize: {target_size: [608, 1088]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {}
......
......@@ -7,7 +7,6 @@ _BASE_: [
EvalMOTDataset:
!MOTImageFolder
task: MOT16_train
dataset_dir: dataset/mot
data_root: MOT16/images/train
keep_ori_im: True # set as True in DeepSORT
......
......@@ -7,7 +7,6 @@ _BASE_: [
EvalMOTDataset:
!MOTImageFolder
task: MOT16_train
dataset_dir: dataset/mot
data_root: MOT16/images/train
keep_ori_im: True # set as True in DeepSORT
......
......@@ -59,7 +59,6 @@ CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/fairmot_d
```
EvalMOTDataset:
!MOTImageFolder
task: MOT17_train
dataset_dir: dataset/mot
data_root: MOT17/images/train
keep_ori_im: False # set True if save visualization images or video
......
......@@ -57,7 +57,6 @@ CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/fairmot_d
```
EvalMOTDataset:
!MOTImageFolder
task: MOT17_train
dataset_dir: dataset/mot
data_root: MOT17/images/train
keep_ori_im: False # set True if save visualization images or video
......
......@@ -22,8 +22,6 @@ TrainReader:
EvalMOTReader:
inputs_def:
image_shape: [3, 608, 1088]
sample_transforms:
- Decode: {}
- LetterBoxResize: {target_size: [608, 1088]}
......@@ -36,6 +34,7 @@ TestMOTReader:
inputs_def:
image_shape: [3, 608, 1088]
sample_transforms:
- Decode: {}
- LetterBoxResize: {target_size: [608, 1088]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]}
- Permute: {}
......
......@@ -65,7 +65,6 @@ CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/jde/jde_darknet53
```
EvalMOTDataset:
!MOTImageFolder
task: MOT17_train
dataset_dir: dataset/mot
data_root: MOT17/images/train
keep_ori_im: False # set True if save visualization images or video
......
......@@ -66,7 +66,6 @@ CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/jde/jde_darknet53
```
EvalMOTDataset:
!MOTImageFolder
task: MOT17_train
dataset_dir: dataset/mot
data_root: MOT17/images/train
keep_ori_im: False # set True if save visualization images or video
......
......@@ -41,6 +41,7 @@ TestMOTReader:
inputs_def:
image_shape: [3, 608, 1088]
sample_transforms:
- Decode: {}
- LetterBoxResize: {target_size: [608, 1088]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {}
......
......@@ -41,6 +41,7 @@ TestMOTReader:
inputs_def:
image_shape: [3, 320, 576]
sample_transforms:
- Decode: {}
- LetterBoxResize: {target_size: [320, 576]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {}
......
......@@ -41,6 +41,7 @@ TestMOTReader:
inputs_def:
image_shape: [3, 480, 864]
sample_transforms:
- Decode: {}
- LetterBoxResize: {target_size: [480, 864]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {}
......
......@@ -84,7 +84,10 @@ class JDE_Detector(Detector):
conf_thres = tp['conf_thres'] if 'conf_thres' in tp else 0.
tracked_thresh = tp['tracked_thresh'] if 'tracked_thresh' in tp else 0.7
metric_type = tp['metric_type'] if 'metric_type' in tp else 'euclidean'
self.tracker = JDETracker(conf_thres=conf_thres, tracked_thresh=tracked_thresh, metric_type=metric_type)
self.tracker = JDETracker(
conf_thres=conf_thres,
tracked_thresh=tracked_thresh,
metric_type=metric_type)
def postprocess(self, pred_dets, pred_embs, threshold):
online_targets = self.tracker.update(pred_dets, pred_embs)
......
......@@ -178,7 +178,9 @@ def mot_keypoint_unite_predict_video(mot_model,
keypoint_results,
visual_thread=FLAGS.keypoint_threshold,
returnimg=True,
ids=online_ids)
ids=online_ids
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown' else
None)
online_im = mot_vis.plot_tracking(
im,
......
......@@ -13,8 +13,10 @@
# limitations under the License.
import os
import sys
import cv2
import glob
import numpy as np
import decord as de
from collections import OrderedDict
try:
from collections.abc import Sequence
......@@ -228,8 +230,18 @@ def mot_label():
@register
@serializable
class MOTImageFolder(DetDataset):
"""
Load MOT dataset with MOT format from image folder or video .
Args:
video_file (str): path of the video file, default ''.
dataset_dir (str): root directory for dataset.
keep_ori_im (bool): whether to keep original image, default False.
Set True when used during MOT model inference while saving
images or video, or used in DeepSORT.
"""
def __init__(self,
task,
video_file=None,
dataset_dir=None,
data_root=None,
image_dir=None,
......@@ -238,20 +250,53 @@ class MOTImageFolder(DetDataset):
**kwargs):
super(MOTImageFolder, self).__init__(
dataset_dir, image_dir, sample_num=sample_num)
self.task = task
self.video_file = video_file
self.data_root = data_root
self.keep_ori_im = keep_ori_im
self._imid2path = {}
self.roidbs = None
self.frame_rate = 30
def check_or_download_dataset(self):
return
def parse_dataset(self, ):
if not self.roidbs:
self.roidbs = self._load_images()
if self.video_file is None:
self.roidbs = self._load_images()
else:
self.roidbs = self._load_video_images()
def _load_video_images(self):
cap = cv2.VideoCapture(self.video_file)
self.frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
extension = self.video_file.split('.')[-1]
output_path = self.video_file.replace('.{}'.format(extension), '')
frames_path = video2frames(self.video_file, output_path)
self.video_frames = sorted(
glob.glob(os.path.join(frames_path, '*.png')))
self.video_length = len(self.video_frames)
logger.info('Length of the video: {:d} frames.'.format(
self.video_length))
ct = 0
records = []
for image in self.video_frames:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
if self.sample_num > 0 and ct >= self.sample_num:
break
rec = {'im_id': np.array([ct]), 'im_file': image}
if self.keep_ori_im:
rec.update({'keep_ori_im': 1})
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "No image file found"
return records
def _parse(self):
def _find_images(self):
image_dir = self.image_dir
if not isinstance(image_dir, Sequence):
image_dir = [image_dir]
......@@ -265,7 +310,7 @@ class MOTImageFolder(DetDataset):
return images
def _load_images(self):
images = self._parse()
images = self._find_images()
ct = 0
records = []
for image in images:
......@@ -289,67 +334,44 @@ class MOTImageFolder(DetDataset):
self.image_dir = images
self.roidbs = self._load_images()
def set_video(self, video_file):
self.video_file = video_file
assert os.path.isfile(self.video_file) and _is_valid_video(self.video_file), \
"wrong or unsupported file format: {}".format(self.video_file)
self.roidbs = self._load_video_images()
def _is_valid_video(f, extensions=('.mp4', '.avi', '.mov', '.rmvb', 'flv')):
return f.lower().endswith(extensions)
@register
@serializable
class MOTVideoDataset(DetDataset):
"""
Load MOT dataset with MOT format from video for inference.
Args:
video_file (str): path of the video file
dataset_dir (str): root directory for dataset.
keep_ori_im (bool): whether to keep original image, default False.
Set True when used during MOT model inference while saving
images or video, or used in DeepSORT.
"""
def video2frames(video_path, outpath, **kargs):
def _dict2str(kargs):
cmd_str = ''
for k, v in kargs.items():
cmd_str += (' ' + str(k) + ' ' + str(v))
return cmd_str
def __init__(self,
video_file='',
dataset_dir=None,
keep_ori_im=False,
**kwargs):
super(MOTVideoDataset, self).__init__(dataset_dir=dataset_dir)
self.video_file = video_file
self.dataset_dir = dataset_dir
self.keep_ori_im = keep_ori_im
self.roidbs = None
self.frame_rate = 25
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
vid_name = os.path.basename(video_path).split('.')[0]
out_full_path = os.path.join(outpath, vid_name)
def parse_dataset(self, ):
if not self.roidbs:
self.roidbs = self._load_video_images()
if not os.path.exists(out_full_path):
os.makedirs(out_full_path)
def _load_video_images(self):
self.video_frames = de.VideoReader(self.video_file)
self.video_length = len(self.video_frames)
logger.info('Length of the video: {:d} frames.'.format(
self.video_length))
records = []
for idx in range(self.video_length):
image = self.video_frames.get_batch([idx]).asnumpy()[0]
im_shape = image.shape
rec = {
'im_id': np.array([idx]),
'image': image,
'h': im_shape[0],
'w': im_shape[1],
'im_shape': np.array(
im_shape[:2], dtype=np.float32),
'scale_factor': np.array(
[1., 1.], dtype=np.float32),
}
if self.keep_ori_im:
rec.update({'ori_image': image})
records.append(rec)
assert len(records) > 0, "No image file found."
return records
# video file name
outformat = os.path.join(out_full_path, '%08d.png')
def set_video(self, video_file):
self.video_file = video_file
assert os.path.isfile(self.video_file) and _is_valid_video(self.video_file), \
"wrong or unsupported file format: {}".format(self.video_file)
self.roidbs = self._load_video_images()
cmd = ffmpeg
cmd = ffmpeg + [' -i ', video_path, ' -start_number ', ' 0 ', outformat]
cmd = ''.join(cmd) + _dict2str(kargs)
try:
os.system(cmd)
except:
raise RuntimeError('ffmpeg process video: {} error'.format(vid_name))
sys.stdout.flush()
sys.exit(-1)
sys.stdout.flush()
return out_full_path
......@@ -58,9 +58,7 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
label_list = [str(cat) for cat in catid2name.values()]
sample_transforms = reader_cfg['sample_transforms']
if arch != 'mot_arch':
sample_transforms = sample_transforms[1:]
for st in sample_transforms:
for st in sample_transforms[1:]:
for key, value in st.items():
p = {'type': key}
if key == 'Resize':
......@@ -82,12 +80,14 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
return preprocess_list, label_list
def _parse_tracker(tracker_cfg):
tracker_params = {}
for k, v in tracker_cfg.items():
tracker_params.update({k: v})
return tracker_params
def _dump_infer_config(config, path, image_shape, model):
arch_state = False
from ppdet.core.config.yaml_helpers import setup_orderdict
......
......@@ -282,18 +282,25 @@ class Tracker(object):
n_frame = 0
timer_avgs, timer_calls = [], []
for seq in seqs:
if not os.path.isdir(os.path.join(data_root, seq)):
continue
infer_dir = os.path.join(data_root, seq, 'img1')
seqinfo = os.path.join(data_root, seq, 'seqinfo.ini')
if not os.path.exists(seqinfo) or not os.path.exists(
infer_dir) or not os.path.isdir(infer_dir):
continue
save_dir = os.path.join(output_dir, 'mot_outputs',
seq) if save_images or save_videos else None
logger.info('start seq: {}'.format(seq))
infer_dir = os.path.join(data_root, seq, 'img1')
images = self.get_infer_images(infer_dir)
self.dataset.set_images(images)
dataloader = create('EvalMOTReader')(self.dataset, 0)
result_filename = os.path.join(result_root, '{}.txt'.format(seq))
meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()
meta_info = open(seqinfo).read()
frame_rate = int(meta_info[meta_info.find('frameRate') + 10:
meta_info.find('\nseqLength')])
with paddle.no_grad():
......@@ -365,6 +372,7 @@ class Tracker(object):
def mot_predict(self,
video_file,
image_dir,
output_dir,
data_type='mot',
model_type='JDE',
......@@ -373,6 +381,13 @@ class Tracker(object):
show_image=False,
det_results_dir='',
draw_threshold=0.5):
assert video_file is not None or image_dir is not None, \
"--video_file or --image_dir should be set."
assert video_file is None or os.path.isfile(video_file), \
"{} is not a file".format(video_file)
assert image_dir is None or os.path.isdir(image_dir), \
"{} is not a directory".format(image_dir)
if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results')
if not os.path.exists(result_root): os.makedirs(result_root)
......@@ -381,13 +396,26 @@ class Tracker(object):
assert model_type in ['JDE', 'DeepSORT', 'FairMOT'], \
"model_type should be 'JDE', 'DeepSORT' or 'FairMOT'"
# run tracking
seq = video_file.split('/')[-1].split('.')[0]
# run tracking
if video_file:
seq = video_file.split('/')[-1].split('.')[0]
self.dataset.set_video(video_file)
logger.info('Starting tracking video {}'.format(video_file))
elif image_dir:
seq = image_dir.split('/')[-1].split('.')[0]
images = [
'{}/{}'.format(image_dir, x) for x in os.listdir(image_dir)
]
images.sort()
self.dataset.set_images(images)
logger.info('Starting tracking folder {}, found {} images'.format(
image_dir, len(images)))
else:
raise ValueError('--video_file or --image_dir should be set.')
save_dir = os.path.join(output_dir, 'mot_outputs',
seq) if save_images or save_videos else None
logger.info('Starting tracking {}'.format(video_file))
self.dataset.set_video(video_file)
dataloader = create('TestMOTReader')(self.dataset, 0)
result_filename = os.path.join(result_root, '{}.txt'.format(seq))
frame_rate = self.dataset.frame_rate
......
......@@ -73,11 +73,11 @@ def parse_args():
def run(FLAGS, cfg):
task = cfg['EvalMOTDataset'].task
dataset_dir = cfg['EvalMOTDataset'].dataset_dir
data_root = cfg['EvalMOTDataset'].data_root
data_root = '{}/{}'.format(dataset_dir, data_root)
seqs = cfg['MOTDataZoo'][task]
seqs = os.listdir(data_root)
seqs.sort()
# build Tracker
tracker = Tracker(cfg, mode='eval')
......
......@@ -43,6 +43,11 @@ def parse_args():
parser = ArgsParser()
parser.add_argument(
'--video_file', type=str, default=None, help='Video name for tracking.')
parser.add_argument(
"--image_dir",
type=str,
default=None,
help="Directory for images to perform inference on.")
parser.add_argument(
"--data_type",
type=str,
......@@ -95,6 +100,7 @@ def run(FLAGS, cfg):
# inference
tracker.mot_predict(
video_file=FLAGS.video_file,
image_dir=FLAGS.image_dir,
data_type=FLAGS.data_type,
model_type=cfg.architecture,
output_dir=FLAGS.output_dir,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册