Unverified commit 22869545, authored by chenjian, committed by GitHub

Add jde module (#1587)

Parent: dedf4bfe
# jde_darknet53
|Module Name|jde_darknet53|
| :--- | :---: |
|Category|Video - Multiple Object Tracking|
|Network|YOLOv3|
|Dataset|Caltech Pedestrian+CityPersons+CUHK-SYSU+PRW+ETHZ+MOT17|
|Fine-tuning supported|No|
|Module Size|420MB|
|Latest update date|2021-08-26|
|Data indicators|-|
## I. Basic Information
- ### Application Effect Display
  - Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/131989578-ec06e18f-e122-40b0-84d2-8772fd35391a.gif" hspace='10'/> <br />
</p>
- ### Module Introduction
  - JDE (Joint Detection and Embedding) learns the object detection task and the appearance embedding task simultaneously in a single shared network, and outputs detection results together with the matching appearance embeddings in one pass. The original JDE adds a ReID branch that learns embeddings on top of the anchor-based YOLOv3 detector, and formulates training as a multi-task joint learning problem, balancing accuracy and speed; a conceptual sketch of the resulting per-frame loop is given below.
  - For more details, see: [Towards Real-Time Multi-Object Tracking](https://arxiv.org/abs/1909.12605)
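- For orientation, a minimal sketch of the per-frame loop this joint design enables (the `pred_dets`/`pred_embs` names mirror the `StreamTracker` code included further down in this change; this is a conceptual illustration, not a public API):
- ```python
  def track_stream(model, dataloader):
      """Conceptual JDE loop: one shared forward pass yields detections and
      appearance embeddings; the tracker then associates them across frames."""
      all_targets = []
      for data in dataloader:                      # preprocessed frames
          pred_dets, pred_embs = model(data)       # joint detection + embedding
          all_targets.append(model.tracker.update(pred_dets, pred_embs))
      return all_targets
  ```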
## II. Installation
- ### 1. Environment Dependencies
  - ppdet >= 2.1.0
  - opencv-python
- ### 2. Installation
- ```shell
$ hub install jde_darknet53
```
- If you encounter problems during installation, please refer to: [Windows quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [Linux quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Module API Prediction
- ### 1. Command-line Prediction
- ```shell
# Read from a video file
$ hub run jde_darknet53 --video_stream "/PATH/TO/VIDEO"
```
- This invokes the multiple object tracking module from the command line. For more options, see [PaddleHub command-line instructions](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2. Prediction Code Example
- ```python
import paddlehub as hub
tracker = hub.Module(name="jde_darknet53")
# Read from a video file
tracker.tracking('/PATH/TO/VIDEO', output_dir='mot_result', visualization=True,
                 draw_threshold=0.5, use_gpu=False)
# or read from an image stream
# with tracker.stream_mode(output_dir='image_stream_output', visualization=True, draw_threshold=0.5, use_gpu=True):
# tracker.predict([images])
```
- ### 3. API
- ```python
def tracking(video_stream,
output_dir='',
visualization=True,
draw_threshold=0.5,
use_gpu=False)
```
- Video prediction API: performs multiple object tracking on the video content and saves the tracking results.
- **Parameters**
  - video_stream (str): path of the video file; <br/>
  - output_dir (str): root directory for saving results, defaults to the current directory; <br/>
  - visualization (bool): whether to save the tracking results; <br/>
  - draw\_threshold (float): confidence threshold for predictions; <br/>
  - use\_gpu (bool): whether to use GPU.
- ```python
def stream_mode(output_dir='',
visualization=True,
draw_threshold=0.5,
use_gpu=False)
```
- Image-stream prediction mode API: enters a mode in which multiple object tracking is performed on an image stream and the tracking results are saved.
- **Parameters**
  - output_dir (str): root directory for saving results, defaults to the current directory; <br/>
  - visualization (bool): whether to save the tracking results; <br/>
  - draw\_threshold (float): confidence threshold for predictions; <br/>
  - use\_gpu (bool): whether to use GPU.
- ```python
def predict(images: list = [])
```
- Image prediction API. This interface must be used after the stream_mode API has been entered; see the end-to-end sketch below.
- **Parameters**
  - images (list): list of images to predict.
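- Putting `stream_mode` and `predict` together, a hedged end-to-end sketch (reading frames with OpenCV here is an assumption for illustration; any image source works):
- ```python
  import cv2
  import paddlehub as hub

  tracker = hub.Module(name="jde_darknet53")
  cap = cv2.VideoCapture('/PATH/TO/VIDEO')

  # predict() must be called inside stream_mode; each call feeds frames in
  # and returns the tracking results for those frames.
  with tracker.stream_mode(output_dir='image_stream_output', visualization=True,
                           draw_threshold=0.5, use_gpu=False):
      while True:
          ret, frame = cap.read()
          if not ret:
              break
          results = tracker.predict([frame])
  cap.release()
  ```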
## IV. Release Note
* 1.0.0

  First release
- ```shell
$ hub install jde_darknet53==1.0.0
```
architecture: JDE
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/DarkNet53_pretrained.pdparams
find_unused_parameters: True
JDE:
detector: YOLOv3
reid: JDEEmbeddingHead
tracker: JDETracker
YOLOv3:
backbone: DarkNet
neck: YOLOv3FPN
yolo_head: YOLOv3Head
post_process: JDEBBoxPostProcess
for_mot: True
DarkNet:
depth: 53
return_idx: [2, 3, 4]
freeze_norm: True
YOLOv3FPN:
freeze_norm: True
YOLOv3Head:
anchors: [[128,384], [180,540], [256,640], [512,640],
[32,96], [45,135], [64,192], [90,271],
[8,24], [11,34], [16,48], [23,68]]
anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
loss: JDEDetectionLoss
JDEBBoxPostProcess:
decode:
name: JDEBox
conf_thresh: 0.3
downsample_ratio: 32
nms:
name: MultiClassNMS
keep_top_k: 500
score_threshold: 0.01
nms_threshold: 0.5
nms_top_k: 2000
normalized: true
JDEEmbeddingHead:
anchor_levels: 3
anchor_scales: 4
embedding_dim: 512
emb_loss: JDEEmbeddingLoss
jde_loss: JDELoss
JDETracker:
det_thresh: 0.3
track_buffer: 30
min_box_area: 200
motion: KalmanFilter
worker_num: 2
TrainReader:
sample_transforms:
- Decode: {}
- RGBReverse: {}
- AugmentHSV: {}
- LetterBoxResize: {target_size: [608, 1088]}
- MOTRandomAffine: {}
- RandomFlip: {}
- BboxXYXY2XYWH: {}
- NormalizeBox: {}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- RGBReverse: {}
- Permute: {}
batch_transforms:
- Gt2JDETargetThres:
anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
anchors: [[[128,384], [180,540], [256,640], [512,640]],
[[32,96], [45,135], [64,192], [90,271]],
[[8,24], [11,34], [16,48], [23,68]]]
downsample_ratios: [32, 16, 8]
ide_thresh: 0.5
fg_thresh: 0.5
bg_thresh: 0.4
batch_size: 4
shuffle: true
drop_last: true
use_shared_memory: true
EvalMOTReader:
sample_transforms:
- Decode: {}
- LetterBoxResize: {target_size: [608, 1088]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {}
batch_size: 1
TestMOTReader:
inputs_def:
image_shape: [3, 608, 1088]
sample_transforms:
- Decode: {}
- LetterBoxResize: {target_size: [608, 1088]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {}
batch_size: 1
MOTVideoStreamReader:
sample_transforms:
- Decode: {}
- LetterBoxResize: {target_size: [608, 1088]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {}
batch_size: 1
metric: MOT
num_classes: 1
# for MOT training
TrainDataset:
!MOTDataSet
dataset_dir: dataset/mot
image_lists: ['mot17.train', 'caltech.all', 'cuhksysu.train', 'prw.train', 'citypersons.train', 'eth.train']
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
# for MOT evaluation
# If you want to change the MOT evaluation dataset, please modify 'data_root'
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: MOT16/images/train
keep_ori_im: False # set True when saving visualization images or video, or when used in DeepSORT
# for MOT video inference
TestMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
keep_ori_im: True # set True when saving visualization images or video
epoch: 30
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [15, 22]
use_warmup: True
- !BurninWarmup
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
epoch: 60
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [30, 44]
use_warmup: True
- !BurninWarmup
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
use_gpu: true
log_iter: 20
save_dir: output
snapshot_epoch: 1
print_flops: false
_BASE_: [
'_base_/mot.yml',
'_base_/runtime.yml',
'_base_/optimizer_30e.yml',
'_base_/jde_darknet53.yml',
'_base_/jde_reader_1088x608.yml',
]
JDE:
detector: YOLOv3
reid: JDEEmbeddingHead
tracker: JDETracker
YOLOv3:
backbone: DarkNet
neck: YOLOv3FPN
yolo_head: YOLOv3Head
post_process: JDEBBoxPostProcess
for_mot: True
YOLOv3Head:
anchors: [[128,384], [180,540], [256,640], [512,640],
[32,96], [45,135], [64,192], [90,271],
[8,24], [11,34], [16,48], [23,68]]
anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
loss: JDEDetectionLoss
JDETracker:
det_thresh: 0.3
track_buffer: 30
min_box_area: 200
motion: KalmanFilter
JDEBBoxPostProcess:
decode:
name: JDEBox
conf_thresh: 0.5
downsample_ratio: 32
nms:
name: MultiClassNMS
keep_top_k: 500
score_threshold: 0.01
nms_threshold: 0.4
nms_top_k: 2000
normalized: true
return_index: true
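The `_BASE_` list above merges the component configs through ppdet's workspace mechanism; the module code below loads the composed file like this (a minimal sketch; the path is relative to the installed module directory, as used in `tracking`):

```python
import paddle
from ppdet.core.workspace import load_config
from ppdet.utils.check import check_config

# The _BASE_ files are merged automatically when the top-level config is loaded.
cfg = load_config('config/jde_darknet53_30e_1088x608.yml')
check_config(cfg)
paddle.set_device('cpu')  # or 'gpu:0' when running on GPU
```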
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import six
from collections.abc import Mapping
from collections import deque
from ppdet.core.workspace import register, serializable
from ppdet.utils.logger import setup_logger
from ppdet.data.reader import BaseDataLoader, Compose
from paddle.fluid.dataloader.collate import default_collate_fn
import cv2
from imageio import imread, imwrite
import numpy as np
import paddle
logger = setup_logger(__name__)
@register
@serializable
class MOTVideoStream:
"""
Load MOT dataset with MOT format from video stream.
Args:
video_stream (str): path or url of the video file, default ''.
keep_ori_im (bool): whether to keep original image, default False.
Set True when used during MOT model inference while saving
images or video, or used in DeepSORT.
"""
def __init__(self, video_stream=None, keep_ori_im=False, **kwargs):
self.video_stream = video_stream
self.keep_ori_im = keep_ori_im
self._curr_iter = 0
self.transform = None
try:
if video_stream is None:
print('No video stream is specified, please check the --video_stream option.')
raise FileNotFoundError("No video_stream is specified.")
self.stream = cv2.VideoCapture(video_stream)
if not self.stream.isOpened():
raise Exception("Open video stream Error!")
except Exception as e:
print('Failed to read {}.'.format(video_stream))
raise e
self.videoframeraw_dir = os.path.splitext(os.path.basename(self.video_stream))[0] + '_raw'
if not os.path.exists(self.videoframeraw_dir):
os.makedirs(self.videoframeraw_dir)
def set_kwargs(self, **kwargs):
self.mixup_epoch = kwargs.get('mixup_epoch', -1)
self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
def set_transform(self, transform):
self.transform = transform
def set_epoch(self, epoch_id):
self._epoch = epoch_id
def parse_dataset(self):
pass
def __iter__(self):
ct = 0
while True:
ret, frame = self.stream.read()
if ret:
imgname = os.path.join(self.videoframeraw_dir, 'frame{}.png'.format(ct))
cv2.imwrite(imgname, frame)
image = imread(imgname)
rec = {'im_id': np.array([ct]), 'im_file': imgname}
if self.keep_ori_im:
rec.update({'keep_ori_im': 1})
rec['curr_iter'] = self._curr_iter
self._curr_iter += 1
ct += 1
if self.transform:
yield self.transform(rec)
else:
yield rec
else:
return
@register
@serializable
class MOTImageStream:
"""
Load MOT dataset with MOT format from image stream.
Args:
keep_ori_im (bool): whether to keep original image, default False.
Set True when used during MOT model inference while saving
images or video, or used in DeepSORT.
"""
def __init__(self, sample_num=-1, keep_ori_im=False, **kwargs):
self.keep_ori_im = keep_ori_im
self._curr_iter = 0
self.transform = None
self.imagequeue = deque()
self.frameraw_dir = 'inputimages_raw'
if not os.path.exists(self.frameraw_dir):
os.makedirs(self.frameraw_dir)
def add_image(self, image):
self.imagequeue.append(image)
def set_kwargs(self, **kwargs):
self.mixup_epoch = kwargs.get('mixup_epoch', -1)
self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
def set_transform(self, transform):
self.transform = transform
def set_epoch(self, epoch_id):
self._epoch = epoch_id
def parse_dataset(self):
pass
def __iter__(self):
ct = 0
while True:
if self.imagequeue:
frame = self.imagequeue.popleft()
imgname = os.path.join(self.frameraw_dir, 'frame{}.png'.format(ct))
cv2.imwrite(imgname, frame)
image = imread(imgname)
rec = {'im_id': np.array([ct]), 'im_file': imgname}
if self.keep_ori_im:
rec.update({'keep_ori_im': 1})
rec['curr_iter'] = self._curr_iter
self._curr_iter += 1
ct += 1
if self.transform:
yield self.transform(rec)
else:
yield rec
else:
return
@register
class MOTVideoStreamReader:
__shared__ = ['num_classes']
def __init__(self, sample_transforms=[], batch_size=1, drop_last=False, num_classes=1, **kwargs):
self._sample_transforms = Compose(sample_transforms, num_classes=num_classes)
self.batch_size = batch_size
self.drop_last = drop_last
self.num_classes = num_classes
self.kwargs = kwargs
def __call__(
self,
dataset,
worker_num,
):
self.dataset = dataset
# get data
self.dataset.set_transform(self._sample_transforms)
# set kwargs
self.dataset.set_kwargs(**self.kwargs)
self.loader = iter(self.dataset)
return self
def __len__(self):
return sys.maxsize  # stream length is unknown; report an effectively unbounded size (sys.maxint was removed in Python 3)
def __iter__(self):
return self
def to_tensor(self, batch):
paddle.disable_static()
if isinstance(batch, np.ndarray):
batch = paddle.to_tensor(batch)
elif isinstance(batch, Mapping):
batch = {key: self.to_tensor(batch[key]) for key in batch}
return batch
def __next__(self):
try:
batch = []
for i in range(self.batch_size):
batch.append(next(self.loader))
batch = default_collate_fn(batch)
return self.to_tensor(batch)
except StopIteration as e:
raise e
def next(self):
# python2 compatibility
return self.__next__()
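For orientation, the classes above are wired together at inference time roughly as follows (a sketch of what `StreamTracker` in the code below does; it assumes the module config, which configures `MOTVideoStreamReader`, has already been loaded into ppdet's workspace):

```python
from ppdet.core.workspace import create

# MOTImageStream buffers frames pushed in via add_image(); the reader applies
# the sample transforms and batches frames as an endless iterator.
dataset = MOTImageStream(keep_ori_im=True)
loader = create('MOTVideoStreamReader')(dataset, 0)  # second argument: worker_num (unused)
frames = iter(loader)
```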
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import signal
import glob
import argparse
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.engine import Tracker
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.logger import setup_logger
import paddlehub as hub
from paddlehub.module.module import moduleinfo, serving, runnable
import cv2
from .tracker import StreamTracker
logger = setup_logger('Predict')
@moduleinfo(name="jde_darknet53",
type="CV/multiple_object_tracking",
author="paddlepaddle",
author_email="",
summary="JDE is a joint detection and appearance embedding model for multiple object tracking.",
version="1.0.0")
class JDETracker_1088x608:
def __init__(self):
self.pretrained_model = os.path.join(self.directory, "jde_darknet53_30e_1088x608")
def tracking(self, video_stream, output_dir='mot_result', visualization=True, draw_threshold=0.5, use_gpu=False):
'''
Track a video and save the prediction results into output_dir if visualization is set to True.
video_stream: the video path
output_dir: specify the dir to save the results
visualization: if True, save the results as a video, otherwise not.
draw_threshold: the threshold for the prediction results
use_gpu: if True, use gpu to perform the computation, otherwise cpu.
'''
self.video_stream = video_stream
self.output_dir = output_dir
self.visualization = visualization
self.draw_threshold = draw_threshold
self.use_gpu = use_gpu
cfg = load_config(os.path.join(self.directory, 'config', 'jde_darknet53_30e_1088x608.yml'))
check_config(cfg)
place = 'gpu:0' if use_gpu else 'cpu'
place = paddle.set_device(place)
paddle.disable_static()
tracker = StreamTracker(cfg, mode='test')
# load weights
tracker.load_weights_jde(self.pretrained_model)
signal.signal(signal.SIGINT, self.signalhandler)
# inference
tracker.videostream_predict(video_stream=video_stream,
output_dir=output_dir,
data_type='mot',
model_type='JDE',
visualization=visualization,
draw_threshold=draw_threshold)
def stream_mode(self, output_dir='mot_result', visualization=True, draw_threshold=0.5, use_gpu=False):
'''
Entering stream mode enables image-stream prediction: users can feed images in like a stream and save the results to a video.
output_dir: specify the dir to save the results
visualization: if True, save the results as a video, otherwise not.
draw_threshold: the threshold for the prediction results
use_gpu: if True, use gpu to perform the computation, otherwise cpu.
'''
self.output_dir = output_dir
self.visualization = visualization
self.draw_threshold = draw_threshold
self.use_gpu = use_gpu
cfg = load_config(os.path.join(self.directory, 'config', 'jde_darknet53_30e_1088x608.yml'))
check_config(cfg)
place = 'gpu:0' if use_gpu else 'cpu'
place = paddle.set_device(place)
paddle.disable_static()
self.tracker = StreamTracker(cfg, mode='test')
# load weights
self.tracker.load_weights_jde(self.pretrained_model)
signal.signal(signal.SIGINT, self.signalhandler)
return self
def __enter__(self):
self.tracker_generator = self.tracker.imagestream_predict(self.output_dir,
data_type='mot',
model_type='JDE',
visualization=self.visualization,
draw_threshold=self.draw_threshold)
next(self.tracker_generator)
def __exit__(self, exc_type, exc_value, traceback):
seq = 'inputimages'
save_dir = os.path.join(self.output_dir, 'mot_outputs', seq) if self.visualization else None
if self.visualization:
#### Save using ffmpeg
#output_video_path = os.path.join(save_dir, '..', '{}_vis.mp4'.format(seq))
#cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
# save_dir, output_video_path)
#os.system(cmd_str)
#### Save using opencv
output_video_path = os.path.join(save_dir, '..', '{}_vis.avi'.format(seq))
imgnames = glob.glob(os.path.join(save_dir, '*.jpg'))
if len(imgnames) == 0:
logger.info('No output images to save for video')
return
img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=[img.shape[1], img.shape[0]])
for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath)
video_writer.write(img)
video_writer.release()
logger.info('Saved video to {}'.format(output_video_path))
def predict(self, images: list = []):
'''
Predict the images. This method must be called in stream mode (see stream_mode).
images: the image list used for prediction.
Example:
tracker = hub.Module(name="jde_darknet53")
with tracker.stream_mode(output_dir='image_stream_output', visualization=True, draw_threshold=0.5, use_gpu=True):
tracker.predict([images])
'''
length = len(images)
if length == 0:
print('No images provided.')
return
for image in images:
self.tracker.dataset.add_image(image)
try:
results = next(self.tracker_generator)
except StopIteration as e:
return
return results[-length:]
@runnable
def run_cmd(self, argvs: list):
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
self.args = self.parser.parse_args(argvs)
self.tracking(
video_stream=self.args.video_stream,
output_dir=self.args.output_dir,
visualization=self.args.visualization,
draw_threshold=self.args.draw_threshold,
use_gpu=self.args.use_gpu,
)
def signalhandler(self, signum, frame):
seq = os.path.splitext(os.path.basename(self.video_stream))[0]
save_dir = os.path.join(self.output_dir, 'mot_outputs', seq) if self.visualization else None
if self.visualization:
#### Save using ffmpeg
#output_video_path = os.path.join(save_dir, '..', '{}_vis.mp4'.format(seq))
#cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
# save_dir, output_video_path)
#os.system(cmd_str)
#### Save using opencv
output_video_path = os.path.join(save_dir, '..', '{}_vis.avi'.format(seq))
imgnames = glob.glob(os.path.join(save_dir, '*.jpg'))
if len(imgnames) == 0:
logger.info('No output images to save for video')
return
img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path,
fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps=30,
frameSize=[img.shape[1], img.shape[0]])
for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath)
video_writer.write(img)
video_writer.release()
logger.info('Saved video to {}'.format(output_video_path))
print('Program interrupted! Saved video to {}'.format(output_video_path))
exit(0)
def add_module_config_arg(self):
"""
Add the command config options.
"""
self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
self.arg_config_group.add_argument('--output_dir',
type=str,
default='mot_result',
help='Directory name for output tracking results.')
self.arg_config_group.add_argument('--visualization',
action='store_true',
help="whether to save output as images.")
self.arg_config_group.add_argument("--draw_threshold",
type=float,
default=0.5,
help="Threshold to reserve the result for visualization.")
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument('--video_stream',
type=str,
help="path to video stream, can be a video file or stream device number.")
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import glob
import paddle
import numpy as np
from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
from ppdet.modeling.mot.utils import Timer, load_det_results
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric
import ppdet.utils.stats as stats
from ppdet.engine.callbacks import Callback, ComposeCallback
from ppdet.utils.logger import setup_logger
from .dataset import MOTVideoStream, MOTImageStream
logger = setup_logger(__name__)
class StreamTracker(object):
def __init__(self, cfg, mode='eval'):
self.cfg = cfg
assert mode.lower() in ['test', 'eval'], \
"mode should be 'test' or 'eval'"
self.mode = mode.lower()
self.optimizer = None
# build model
self.model = create(cfg.architecture)
self.status = {}
self.start_epoch = 0
def load_weights_jde(self, weights):
load_weight(self.model, weights, self.optimizer)
def _eval_seq_jde(self, dataloader, save_dir=None, show_image=False, frame_rate=30, draw_threshold=0):
if save_dir:
if not os.path.exists(save_dir): os.makedirs(save_dir)
tracker = self.model.tracker
tracker.max_time_lost = int(frame_rate / 30.0 * tracker.track_buffer)
timer = Timer()
results = []
frame_id = 0
self.status['mode'] = 'track'
self.model.eval()
for step_id, data in enumerate(dataloader):
#print('data', data)
self.status['step_id'] = step_id
if frame_id % 40 == 0:
logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
# forward
timer.tic()
pred_dets, pred_embs = self.model(data)
online_targets = self.model.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_ids = [], []
online_scores = []
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tscore < draw_threshold: continue
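                # a box much wider than tall (w/h > 1.6) is unlikely to be an upright
                # pedestrian, so it is dropped along with boxes under min_box_area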
vertical = tlwh[2] / tlwh[3] > 1.6
if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical:
online_tlwhs.append(tlwh)
online_ids.append(tid)
online_scores.append(tscore)
timer.toc()
# save results
results.append((frame_id + 1, online_tlwhs, online_scores, online_ids))
self.save_results(data, frame_id, online_ids, online_tlwhs, online_scores, timer.average_time, show_image,
save_dir)
frame_id += 1
return results, frame_id, timer.average_time, timer.calls
def _eval_seq_jde_single_image(self, iterator, save_dir=None, show_image=False, draw_threshold=0):
if save_dir:
if not os.path.exists(save_dir): os.makedirs(save_dir)
tracker = self.model.tracker
results = []
frame_id = 0
self.status['mode'] = 'track'
self.model.eval()
timer = Timer()
while True:
try:
data = next(iterator)
timer.tic()
with paddle.no_grad():
pred_dets, pred_embs = self.model(data)
online_targets = self.model.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_ids = [], []
online_scores = []
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tscore < draw_threshold: continue
vertical = tlwh[2] / tlwh[3] > 1.6
if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical:
online_tlwhs.append(tlwh)
online_ids.append(tid)
online_scores.append(tscore)
timer.toc()
# save results
results.append((frame_id + 1, online_tlwhs, online_scores, online_ids))
self.save_results(data, frame_id, online_ids, online_tlwhs, online_scores, timer.average_time,
show_image, save_dir)
frame_id += 1
yield results, frame_id
except StopIteration as e:
return
def imagestream_predict(self, output_dir, data_type='mot', model_type='JDE', visualization=True,
draw_threshold=0.5):
if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results')
if not os.path.exists(result_root): os.makedirs(result_root)
assert data_type in ['mot', 'kitti'], \
"data_type should be 'mot' or 'kitti'"
assert model_type in ['JDE', 'FairMOT'], \
"model_type should be 'JDE', or 'FairMOT'"
seq = 'inputimages'
self.dataset = MOTImageStream(keep_ori_im=True)
save_dir = os.path.join(output_dir, 'mot_outputs', seq) if visualization else None
self.dataloader = create('MOTVideoStreamReader')(self.dataset, 0)
self.dataloader_iter = iter(self.dataloader)
result_filename = os.path.join(result_root, '{}.txt'.format(seq))
if model_type in ['JDE', 'FairMOT']:
generator = self._eval_seq_jde_single_image(
self.dataloader_iter, save_dir=save_dir, draw_threshold=draw_threshold)
else:
raise ValueError(model_type)
yield
results = []
while True:
with paddle.no_grad():
try:
results, nf = next(generator)
yield results
except StopIteration as e:
self.write_mot_results(result_filename, results, data_type)
return
def videostream_predict(self,
video_stream,
output_dir,
data_type='mot',
model_type='JDE',
visualization=True,
draw_threshold=0.5):
assert video_stream is not None, \
"--video_file or --image_dir should be set."
if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results')
if not os.path.exists(result_root): os.makedirs(result_root)
assert data_type in ['mot', 'kitti'], \
"data_type should be 'mot' or 'kitti'"
assert model_type in ['JDE', 'FairMOT'], \
"model_type should be 'JDE', or 'FairMOT'"
seq = os.path.splitext(os.path.basename(video_stream))[0]
self.dataset = MOTVideoStream(video_stream, keep_ori_im=True)
save_dir = os.path.join(output_dir, 'mot_outputs', seq) if visualization else None
dataloader = create('MOTVideoStreamReader')(self.dataset, 0)
result_filename = os.path.join(result_root, '{}.txt'.format(seq))
with paddle.no_grad():
if model_type in ['JDE', 'FairMOT']:
results, nf, ta, tc = self._eval_seq_jde(dataloader, save_dir=save_dir, draw_threshold=draw_threshold)
else:
raise ValueError(model_type)
self.write_mot_results(result_filename, results, data_type)
if visualization:
#### Save using ffmpeg
#output_video_path = os.path.join(save_dir, '..', '{}_vis.mp4'.format(seq))
#cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
# save_dir, output_video_path)
#os.system(cmd_str)
#### Save using opencv
output_video_path = os.path.join(save_dir, '..', '{}_vis.avi'.format(seq))
imgnames = glob.glob(os.path.join(save_dir, '*.jpg'))
if len(imgnames) == 0:
logger.info('No output images to save for video')
return
img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
video_writer = cv2.VideoWriter(output_video_path, fourcc=cv2.VideoWriter_fourcc('M','J','P','G'), fps=30, frameSize=[img.shape[1],img.shape[0]])
for i in range(len(imgnames)):
imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
img = cv2.imread(imgpath)
video_writer.write(img)
video_writer.release()
logger.info('Saved video to {}'.format(output_video_path))
def write_mot_results(self, filename, results, data_type='mot'):
if data_type in ['mot', 'mcmot', 'lab']:
save_format = '{frame},{id},{x1},{y1},{w},{h},{score},-1,-1,-1\n'
elif data_type == 'kitti':
save_format = '{frame} {id} car 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
else:
raise ValueError(data_type)
with open(filename, 'w') as f:
for frame_id, tlwhs, tscores, track_ids in results:
if data_type == 'kitti':
frame_id -= 1
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0:
continue
x1, y1, w, h = tlwh
x2, y2 = x1 + w, y1 + h
line = save_format.format(
frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=score)
f.write(line)
logger.info('MOT results saved in {}'.format(filename))
def save_results(self, data, frame_id, online_ids, online_tlwhs, online_scores, average_time, show_image, save_dir):
if show_image or save_dir is not None:
assert 'ori_image' in data
img0 = data['ori_image'].numpy()[0]
online_im = mot_vis.plot_tracking(
img0, online_tlwhs, online_ids, online_scores, frame_id=frame_id, fps=1. / average_time)
if show_image:
cv2.imshow('online_im', online_im)
if save_dir is not None:
cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
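Each line written by `write_mot_results` follows the MOTChallenge text format `frame,id,x1,y1,w,h,score,-1,-1,-1`. A minimal sketch of reading such a file back (the path assumes the default `mot_result` output directory and a hypothetical sequence name):

```python
import numpy as np

# Columns: frame, track id, x1, y1, w, h, score, -1, -1, -1
results = np.loadtxt('mot_result/mot_results/video.txt', delimiter=',')
```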