diff --git a/modules/video/multiple_object_tracking/jde_darknet53/README.md b/modules/video/multiple_object_tracking/jde_darknet53/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..17398169d0cd1e4a7bb77d0102ad675552aea2aa
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/README.md
@@ -0,0 +1,124 @@
+# jde_darknet53
+
+|模型名称|jde_darknet53|
+| :--- | :---: |
+|类别|视频 - 多目标追踪|
+|网络|YOLOv3|
+|数据集|Caltech Pedestrian+CityPersons+CUHK-SYSU+PRW+ETHZ+MOT17|
+|是否支持Fine-tuning|否|
+|模型大小|420MB|
+|最新更新日期|2021-08-26|
+|数据指标|-|
+
+
+## 一、模型基本信息
+
+- ### 应用效果展示
+ - 样例结果示例:
+
+
+
+
+- ### 模型介绍
+
+ - JDE(Joint Detection and Embedding)是在一个单一的共享神经网络中同时学习目标检测任务和embedding任务,并同时输出检测结果和对应的外观embedding匹配的算法。JDE原论文是基于Anchor Base的YOLOv3检测器新增加一个ReID分支学习embedding,训练过程被构建为一个多任务联合学习问题,兼顾精度和速度。
+
+ - 更多详情参考:[Towards Real-Time Multi-Object Tracking](https://arxiv.org/abs/1909.12605)
+
+
+
+## 二、安装
+
+- ### 1、环境依赖
+
+ - ppdet >= 2.1.0
+
+ - opencv-python
+
+- ### 2、安装
+
+ - ```shell
+ $ hub install jde_darknet53
+ ```
+ - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+ | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
+
+
+## 三、模型API预测
+
+- ### 1、命令行预测
+
+ - ```shell
+ # Read from a video file
+ $ hub run jde_darknet53 --video_stream "/PATH/TO/VIDEO"
+ ```
+ - 通过命令行方式实现多目标追踪模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
+
+- ### 2、代码示例
+
+ - ```python
+ import paddlehub as hub
+
+ tracker = hub.Module(name="jde_darknet53")
+ # Read from a video file
+ tracker.tracking('/PATH/TO/VIDEO', output_dir='mot_result', visualization=True,
+ draw_threshold=0.5, use_gpu=False, from_device=False)
+ # or read from a image stream
+ # with tracker.stream_mode(output_dir='image_stream_output', visualization=True, draw_threshold=0.5, use_gpu=True):
+ # tracker.predict([images])
+ ```
+
+- ### 3、API
+
+ - ```python
+ def tracking(video_stream,
+ output_dir='',
+ visualization=True,
+ draw_threshold=0.5,
+ use_gpu=False)
+ ```
+ - 视频预测API,完成对视频内容的多目标追踪,并存储追踪结果。
+
+ - **参数**
+
+ - video_stream (str): 视频文件的路径;
+ - output_dir (str): 结果保存路径的根目录,默认为当前目录;
+ - visualization (bool): 是否保存追踪结果;
+ - use\_gpu (bool): 是否使用 GPU;
+ - draw\_threshold (float): 预测置信度的阈值。
+
+ - ```python
+ def stream_mode(output_dir='',
+ visualization=True,
+ draw_threshold=0.5,
+ use_gpu=False)
+ ```
+ - 进入图片流预测模式API,在该模式中完成对图片流的多目标追踪,并存储追踪结果。
+
+ - **参数**
+
+ - output_dir (str): 结果保存路径的根目录,默认为当前目录;
+ - visualization (bool): 是否保存追踪结果;
+ - use\_gpu (bool): 是否使用 GPU;
+ - draw\_threshold (float): 预测置信度的阈值。
+
+ - ```python
+ def predict(images: list = [])
+ ```
+ - 对图片进行预测的API, 该接口必须在stream_mode API被调用后使用。
+
+ - **参数**
+
+ - images (list): 待预测的图片列表。
+
+
+
+## 四、更新历史
+
+* 1.0.0
+
+ 初始发布
+
+ - ```shell
+ $ hub install jde_darknet53==1.0.0
+ ```
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/jde_darknet53.yml b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/jde_darknet53.yml
new file mode 100644
index 0000000000000000000000000000000000000000..73faa52f662e7db24ef40c25c029561225d1a3b8
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/jde_darknet53.yml
@@ -0,0 +1,56 @@
+architecture: JDE
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/DarkNet53_pretrained.pdparams
+find_unused_parameters: True
+
+JDE:
+ detector: YOLOv3
+ reid: JDEEmbeddingHead
+ tracker: JDETracker
+
+YOLOv3:
+ backbone: DarkNet
+ neck: YOLOv3FPN
+ yolo_head: YOLOv3Head
+ post_process: JDEBBoxPostProcess
+ for_mot: True
+
+DarkNet:
+ depth: 53
+ return_idx: [2, 3, 4]
+ freeze_norm: True
+
+YOLOv3FPN:
+ freeze_norm: True
+
+YOLOv3Head:
+ anchors: [[128,384], [180,540], [256,640], [512,640],
+ [32,96], [45,135], [64,192], [90,271],
+ [8,24], [11,34], [16,48], [23,68]]
+ anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
+ loss: JDEDetectionLoss
+
+JDEBBoxPostProcess:
+ decode:
+ name: JDEBox
+ conf_thresh: 0.3
+ downsample_ratio: 32
+ nms:
+ name: MultiClassNMS
+ keep_top_k: 500
+ score_threshold: 0.01
+ nms_threshold: 0.5
+ nms_top_k: 2000
+ normalized: true
+
+JDEEmbeddingHead:
+ anchor_levels: 3
+ anchor_scales: 4
+ embedding_dim: 512
+ emb_loss: JDEEmbeddingLoss
+ jde_loss: JDELoss
+
+JDETracker:
+ det_thresh: 0.3
+ track_buffer: 30
+ min_box_area: 200
+ motion: KalmanFilter
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/jde_reader_1088x608.yml b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/jde_reader_1088x608.yml
new file mode 100644
index 0000000000000000000000000000000000000000..527600681163527b0365792ab5dfa4f1aea2f120
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/jde_reader_1088x608.yml
@@ -0,0 +1,57 @@
+worker_num: 2
+TrainReader:
+ sample_transforms:
+ - Decode: {}
+ - RGBReverse: {}
+ - AugmentHSV: {}
+ - LetterBoxResize: {target_size: [608, 1088]}
+ - MOTRandomAffine: {}
+ - RandomFlip: {}
+ - BboxXYXY2XYWH: {}
+ - NormalizeBox: {}
+ - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
+ - RGBReverse: {}
+ - Permute: {}
+ batch_transforms:
+ - Gt2JDETargetThres:
+ anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
+ anchors: [[[128,384], [180,540], [256,640], [512,640]],
+ [[32,96], [45,135], [64,192], [90,271]],
+ [[8,24], [11,34], [16,48], [23,68]]]
+ downsample_ratios: [32, 16, 8]
+ ide_thresh: 0.5
+ fg_thresh: 0.5
+ bg_thresh: 0.4
+ batch_size: 4
+ shuffle: true
+ drop_last: true
+ use_shared_memory: true
+
+
+EvalMOTReader:
+ sample_transforms:
+ - Decode: {}
+ - LetterBoxResize: {target_size: [608, 1088]}
+ - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+
+TestMOTReader:
+ inputs_def:
+ image_shape: [3, 608, 1088]
+ sample_transforms:
+ - Decode: {}
+ - LetterBoxResize: {target_size: [608, 1088]}
+ - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+
+MOTVideoStreamReader:
+ sample_transforms:
+ - Decode: {}
+ - LetterBoxResize: {target_size: [608, 1088]}
+ - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
+ - Permute: {}
+ batch_size: 1
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/mot.yml b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/mot.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7107da4905e88847aba29e66346a5c05bc418462
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/mot.yml
@@ -0,0 +1,23 @@
+metric: MOT
+num_classes: 1
+
+# for MOT training
+TrainDataset:
+ !MOTDataSet
+ dataset_dir: dataset/mot
+ image_lists: ['mot17.train', 'caltech.all', 'cuhksysu.train', 'prw.train', 'citypersons.train', 'eth.train']
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
+
+# for MOT evaluation
+# If you want to change the MOT evaluation dataset, please modify 'data_root'
+EvalMOTDataset:
+ !MOTImageFolder
+ dataset_dir: dataset/mot
+ data_root: MOT16/images/train
+ keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
+
+# for MOT video inference
+TestMOTDataset:
+ !MOTImageFolder
+ dataset_dir: dataset/mot
+ keep_ori_im: True # set True if save visualization images or video
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/optimizer_30e.yml b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/optimizer_30e.yml
new file mode 100644
index 0000000000000000000000000000000000000000..eec33930926877319aff8b00de516068e646aaea
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/optimizer_30e.yml
@@ -0,0 +1,19 @@
+epoch: 30
+
+LearningRate:
+ base_lr: 0.01
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones: [15, 22]
+ use_warmup: True
+ - !BurninWarmup
+ steps: 1000
+
+OptimizerBuilder:
+ optimizer:
+ momentum: 0.9
+ type: Momentum
+ regularizer:
+ factor: 0.0001
+ type: L2
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/optimizer_60e.yml b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/optimizer_60e.yml
new file mode 100644
index 0000000000000000000000000000000000000000..986764a42f6bfb24d09cb23b49ed6931dbed9352
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/optimizer_60e.yml
@@ -0,0 +1,19 @@
+epoch: 60
+
+LearningRate:
+ base_lr: 0.01
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones: [30, 44]
+ use_warmup: True
+ - !BurninWarmup
+ steps: 1000
+
+OptimizerBuilder:
+ optimizer:
+ momentum: 0.9
+ type: Momentum
+ regularizer:
+ factor: 0.0001
+ type: L2
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/runtime.yml b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/runtime.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c502ddabeb93d95a850fe7cb83a1e68ceff3a4e4
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/runtime.yml
@@ -0,0 +1,5 @@
+use_gpu: true
+log_iter: 20
+save_dir: output
+snapshot_epoch: 1
+print_flops: false
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/config/jde_darknet53_30e_1088x608.yml b/modules/video/multiple_object_tracking/jde_darknet53/config/jde_darknet53_30e_1088x608.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d2ac3aee460aaa378dcef11c3a3fce9aa4c29f05
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/config/jde_darknet53_30e_1088x608.yml
@@ -0,0 +1,46 @@
+_BASE_: [
+ '_base_/mot.yml',
+ '_base_/runtime.yml',
+ '_base_/optimizer_30e.yml',
+ '_base_/jde_darknet53.yml',
+ '_base_/jde_reader_1088x608.yml',
+]
+
+JDE:
+ detector: YOLOv3
+ reid: JDEEmbeddingHead
+ tracker: JDETracker
+
+YOLOv3:
+ backbone: DarkNet
+ neck: YOLOv3FPN
+ yolo_head: YOLOv3Head
+ post_process: JDEBBoxPostProcess
+ for_mot: True
+
+YOLOv3Head:
+ anchors: [[128,384], [180,540], [256,640], [512,640],
+ [32,96], [45,135], [64,192], [90,271],
+ [8,24], [11,34], [16,48], [23,68]]
+ anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
+ loss: JDEDetectionLoss
+
+JDETracker:
+ det_thresh: 0.3
+ track_buffer: 30
+ min_box_area: 200
+ motion: KalmanFilter
+
+JDEBBoxPostProcess:
+ decode:
+ name: JDEBox
+ conf_thresh: 0.5
+ downsample_ratio: 32
+ nms:
+ name: MultiClassNMS
+ keep_top_k: 500
+ score_threshold: 0.01
+ nms_threshold: 0.4
+ nms_top_k: 2000
+ normalized: true
+ return_index: true
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/dataset.py b/modules/video/multiple_object_tracking/jde_darknet53/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff584d41995237dce107e1f103d862af942ed6a5
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/dataset.py
@@ -0,0 +1,210 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+import six
+from collections.abc import Mapping
+from collections import deque
+
+from ppdet.core.workspace import register, serializable
+from ppdet.utils.logger import setup_logger
+from ppdet.data.reader import BaseDataLoader, Compose
+from paddle.fluid.dataloader.collate import default_collate_fn
+import cv2
+from imageio import imread, imwrite
+import numpy as np
+import paddle
+
+logger = setup_logger(__name__)
+
+
+@register
+@serializable
+class MOTVideoStream:
+ """
+ Load MOT dataset with MOT format from video stream.
+ Args:
+ video_stream (str): path or url of the video file, default ''.
+ keep_ori_im (bool): whether to keep original image, default False.
+ Set True when used during MOT model inference while saving
+ images or video, or used in DeepSORT.
+ """
+ def __init__(self, video_stream=None, keep_ori_im=False, **kwargs):
+ self.video_stream = video_stream
+ self.keep_ori_im = keep_ori_im
+ self._curr_iter = 0
+ self.transform = None
+ try:
+ if video_stream == None:
+ print('No video stream is specified, please check the --video_stream option.')
+ raise FileNotFoundError("No video_stream is specified.")
+ self.stream = cv2.VideoCapture(video_stream)
+ if not self.stream.isOpened():
+ raise Exception("Open video stream Error!")
+ except Exception as e:
+ print('Failed to read {}.'.format(video_stream))
+ raise e
+
+ self.videoframeraw_dir = os.path.splitext(os.path.basename(self.video_stream))[0] + '_raw'
+ if not os.path.exists(self.videoframeraw_dir):
+ os.makedirs(self.videoframeraw_dir)
+
+ def set_kwargs(self, **kwargs):
+ self.mixup_epoch = kwargs.get('mixup_epoch', -1)
+ self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
+ self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
+
+ def set_transform(self, transform):
+ self.transform = transform
+
+ def set_epoch(self, epoch_id):
+ self._epoch = epoch_id
+
+ def parse_dataset(self):
+ pass
+
+ def __iter__(self):
+ ct = 0
+ while True:
+ ret, frame = self.stream.read()
+ if ret:
+ imgname = os.path.join(self.videoframeraw_dir, 'frame{}.png'.format(ct))
+ cv2.imwrite(imgname, frame)
+ image = imread(imgname)
+ rec = {'im_id': np.array([ct]), 'im_file': imgname}
+ if self.keep_ori_im:
+ rec.update({'keep_ori_im': 1})
+ rec['curr_iter'] = self._curr_iter
+ self._curr_iter += 1
+ ct += 1
+ if self.transform:
+ yield self.transform(rec)
+ else:
+ yield rec
+ else:
+ return
+
+
+@register
+@serializable
+class MOTImageStream:
+ """
+ Load MOT dataset with MOT format from image stream.
+ Args:
+ keep_ori_im (bool): whether to keep original image, default False.
+ Set True when used during MOT model inference while saving
+ images or video, or used in DeepSORT.
+ """
+ def __init__(self, sample_num=-1, keep_ori_im=False, **kwargs):
+ self.keep_ori_im = keep_ori_im
+ self._curr_iter = 0
+ self.transform = None
+ self.imagequeue = deque()
+
+ self.frameraw_dir = 'inputimages_raw'
+ if not os.path.exists(self.frameraw_dir):
+ os.makedirs(self.frameraw_dir)
+
+ def add_image(self, image):
+ self.imagequeue.append(image)
+
+ def set_kwargs(self, **kwargs):
+ self.mixup_epoch = kwargs.get('mixup_epoch', -1)
+ self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
+ self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
+
+ def set_transform(self, transform):
+ self.transform = transform
+
+ def set_epoch(self, epoch_id):
+ self._epoch = epoch_id
+
+ def parse_dataset(self):
+ pass
+
+ def __iter__(self):
+ ct = 0
+ while True:
+ if self.imagequeue:
+ frame = self.imagequeue.popleft()
+ imgname = os.path.join(self.frameraw_dir, 'frame{}.png'.format(ct))
+ cv2.imwrite(imgname, frame)
+ image = imread(imgname)
+ rec = {'im_id': np.array([ct]), 'im_file': imgname}
+ if self.keep_ori_im:
+ rec.update({'keep_ori_im': 1})
+ rec['curr_iter'] = self._curr_iter
+ self._curr_iter += 1
+ ct += 1
+ if self.transform:
+ yield self.transform(rec)
+ else:
+ yield rec
+ else:
+ return
+
+
+@register
+class MOTVideoStreamReader:
+ __shared__ = ['num_classes']
+
+ def __init__(self, sample_transforms=[], batch_size=1, drop_last=False, num_classes=1, **kwargs):
+ self._sample_transforms = Compose(sample_transforms, num_classes=num_classes)
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+ self.num_classes = num_classes
+ self.kwargs = kwargs
+
+ def __call__(
+ self,
+ dataset,
+ worker_num,
+ ):
+ self.dataset = dataset
+ # get data
+ self.dataset.set_transform(self._sample_transforms)
+ # set kwargs
+ self.dataset.set_kwargs(**self.kwargs)
+
+ self.loader = iter(self.dataset)
+ return self
+
+ def __len__(self):
+ return sys.maxint
+
+ def __iter__(self):
+ return self
+
+ def to_tensor(self, batch):
+ paddle.disable_static()
+ if isinstance(batch, np.ndarray):
+ batch = paddle.to_tensor(batch)
+ elif isinstance(batch, Mapping):
+ batch = {key: self.to_tensor(batch[key]) for key in batch}
+ return batch
+
+ def __next__(self):
+ try:
+ batch = []
+ for i in range(self.batch_size):
+ batch.append(next(self.loader))
+ batch = default_collate_fn(batch)
+ return self.to_tensor(batch)
+
+ except StopIteration as e:
+ raise e
+
+ def next(self):
+ # python2 compatibility
+ return self.__next__()
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/module.py b/modules/video/multiple_object_tracking/jde_darknet53/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..98b0c287596f004689e7c43a8ee17411c0fc9bf7
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/module.py
@@ -0,0 +1,244 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+import signal
+import glob
+import argparse
+
+import paddle
+from ppdet.core.workspace import load_config, merge_config
+from ppdet.engine import Tracker
+from ppdet.utils.check import check_gpu, check_version, check_config
+from ppdet.utils.logger import setup_logger
+import paddlehub as hub
+from paddlehub.module.module import moduleinfo, serving, runnable
+import cv2
+
+from .tracker import StreamTracker
+
+logger = setup_logger('Predict')
+
+
+@moduleinfo(name="jde_darknet53",
+ type="CV/multiple_object_tracking",
+ author="paddlepaddle",
+ author_email="",
+ summary="JDE is a joint detection and appearance embedding model for multiple object tracking.",
+ version="1.0.0")
+class JDETracker_1088x608:
+ def __init__(self):
+ self.pretrained_model = os.path.join(self.directory, "jde_darknet53_30e_1088x608")
+
+ def tracking(self, video_stream, output_dir='mot_result', visualization=True, draw_threshold=0.5, use_gpu=False):
+ '''
+ Track a video, and save the prediction results into output_dir, if visualization is set as True.
+
+ video_stream: the video path
+ output_dir: specify the dir to save the results
+ visualization: if True, save the results as a video, otherwise not.
+ draw_threshold: the threshold for the prediction results
+ use_gpu: if True, use gpu to perform the computation, otherwise cpu.
+ '''
+ self.video_stream = video_stream
+ self.output_dir = output_dir
+ self.visualization = visualization
+ self.draw_threshold = draw_threshold
+ self.use_gpu = use_gpu
+
+ cfg = load_config(os.path.join(self.directory, 'config', 'jde_darknet53_30e_1088x608.yml'))
+ check_config(cfg)
+
+ place = 'gpu:0' if use_gpu else 'cpu'
+ place = paddle.set_device(place)
+
+ paddle.disable_static()
+ tracker = StreamTracker(cfg, mode='test')
+
+ # load weights
+ tracker.load_weights_jde(self.pretrained_model)
+ signal.signal(signal.SIGINT, self.signalhandler)
+ # inference
+ tracker.videostream_predict(video_stream=video_stream,
+ output_dir=output_dir,
+ data_type='mot',
+ model_type='JDE',
+ visualization=visualization,
+ draw_threshold=draw_threshold)
+
+ def stream_mode(self, output_dir='mot_result', visualization=True, draw_threshold=0.5, use_gpu=False):
+ '''
+ Entering the stream mode enables image stream prediction. Users can predict the images like a stream and save the results to a video.
+
+ output_dir: specify the dir to save the results
+ visualization: if True, save the results as a video, otherwise not.
+ draw_threshold: the threshold for the prediction results
+ use_gpu: if True, use gpu to perform the computation, otherwise cpu.
+ '''
+ self.output_dir = output_dir
+ self.visualization = visualization
+ self.draw_threshold = draw_threshold
+ self.use_gpu = use_gpu
+
+ cfg = load_config(os.path.join(self.directory, 'config', 'jde_darknet53_30e_1088x608.yml'))
+ check_config(cfg)
+
+ place = 'gpu:0' if use_gpu else 'cpu'
+ place = paddle.set_device(place)
+
+ paddle.disable_static()
+ self.tracker = StreamTracker(cfg, mode='test')
+
+ # load weights
+ self.tracker.load_weights_jde(self.pretrained_model)
+ signal.signal(signal.SIGINT, self.signalhandler)
+ return self
+
+ def __enter__(self):
+ self.tracker_generator = self.tracker.imagestream_predict(self.output_dir,
+ data_type='mot',
+ model_type='JDE',
+ visualization=self.visualization,
+ draw_threshold=self.draw_threshold)
+ next(self.tracker_generator)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ seq = 'inputimages'
+ save_dir = os.path.join(self.output_dir, 'mot_outputs', seq) if self.visualization else None
+ if self.visualization:
+ #### Save using ffmpeg
+ #output_video_path = os.path.join(save_dir, '..', '{}_vis.mp4'.format(seq))
+ #cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
+ # save_dir, output_video_path)
+ #os.system(cmd_str)
+ #### Save using opencv
+ output_video_path = os.path.join(save_dir, '..', '{}_vis.avi'.format(seq))
+ imgnames = glob.glob(os.path.join(save_dir, '*.jpg'))
+ if len(imgnames) == 0:
+ logger.info('No output images to save for video')
+ return
+ img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
+ video_writer = cv2.VideoWriter(output_video_path,
+ fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
+ fps=30,
+ frameSize=[img.shape[1], img.shape[0]])
+ for i in range(len(imgnames)):
+ imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
+ img = cv2.imread(imgpath)
+ video_writer.write(img)
+ video_writer.release()
+ logger.info('Save video in {}'.format(output_video_path))
+
+ def predict(self, images: list = []):
+ '''
+ Predict the images. This method should called in stream_mode.
+
+ images: the image list used for prediction.
+
+ Example:
+ tracker = hub.Module('fairmot_dla34')
+ with tracker.stream_mode(output_dir='image_stream_output', visualization=True, draw_threshold=0.5, use_gpu=True):
+ tracker.predict([images])
+ '''
+ length = len(images)
+ if length == 0:
+ print('No images provided.')
+ return
+ for image in images:
+ self.tracker.dataset.add_image(image)
+ try:
+ results = next(self.tracker_generator)
+ except StopIteration as e:
+ return
+
+ return results[-length:]
+
+ @runnable
+ def run_cmd(self, argvs: list):
+ """
+ Run as a command.
+ """
+ self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
+ prog='hub run {}'.format(self.name),
+ usage='%(prog)s',
+ add_help=True)
+
+ self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
+ self.arg_config_group = self.parser.add_argument_group(
+ title="Config options", description="Run configuration for controlling module behavior, not required.")
+ self.add_module_config_arg()
+ self.add_module_input_arg()
+ self.args = self.parser.parse_args(argvs)
+ self.tracking(
+ video_stream=self.args.video_stream,
+ output_dir=self.args.output_dir,
+ visualization=self.args.visualization,
+ draw_threshold=self.args.draw_threshold,
+ use_gpu=self.args.use_gpu,
+ )
+
+ def signalhandler(self, signum, frame):
+ seq = os.path.splitext(os.path.basename(self.video_stream))[0]
+ save_dir = os.path.join(self.output_dir, 'mot_outputs', seq) if self.visualization else None
+ if self.visualization:
+ #### Save using ffmpeg
+ #output_video_path = os.path.join(save_dir, '..', '{}_vis.mp4'.format(seq))
+ #cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
+ # save_dir, output_video_path)
+ #os.system(cmd_str)
+ #### Save using opencv
+ output_video_path = os.path.join(save_dir, '..', '{}_vis.avi'.format(seq))
+ imgnames = glob.glob(os.path.join(save_dir, '*.jpg'))
+ if len(imgnames) == 0:
+ logger.info('No output images to save for video')
+ return
+ img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
+ video_writer = cv2.VideoWriter(output_video_path,
+ fourcc=cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
+ fps=30,
+ frameSize=[img.shape[1], img.shape[0]])
+ for i in range(len(imgnames)):
+ imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
+ img = cv2.imread(imgpath)
+ video_writer.write(img)
+ video_writer.release()
+ logger.info('Save video in {}'.format(output_video_path))
+ print('Program Interrupted! Save video in {}'.format(output_video_path))
+ exit(0)
+
+ def add_module_config_arg(self):
+ """
+ Add the command config options.
+ """
+ self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
+
+ self.arg_config_group.add_argument('--output_dir',
+ type=str,
+ default='mot_result',
+ help='Directory name for output tracking results.')
+ self.arg_config_group.add_argument('--visualization',
+ action='store_true',
+ help="whether to save output as images.")
+ self.arg_config_group.add_argument("--draw_threshold",
+ type=float,
+ default=0.5,
+ help="Threshold to reserve the result for visualization.")
+
+ def add_module_input_arg(self):
+ """
+ Add the command input options.
+ """
+ self.arg_input_group.add_argument('--video_stream',
+ type=str,
+ help="path to video stream, can be a video file or stream device number.")
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/requirements.txt b/modules/video/multiple_object_tracking/jde_darknet53/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d2dc6ba96ad7d025f5ed9fc954f8e6030175227b
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/requirements.txt
@@ -0,0 +1,2 @@
+ppdet >= 2.1.0
+opencv-python
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/tracker.py b/modules/video/multiple_object_tracking/jde_darknet53/tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a984359edc18cf7e128e12e754fe201bf55fccc
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/tracker.py
@@ -0,0 +1,256 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import cv2
+import glob
+import paddle
+import numpy as np
+
+from ppdet.core.workspace import create
+from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
+from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
+from ppdet.modeling.mot.utils import Timer, load_det_results
+from ppdet.modeling.mot import visualization as mot_vis
+from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric
+import ppdet.utils.stats as stats
+from ppdet.engine.callbacks import Callback, ComposeCallback
+from ppdet.utils.logger import setup_logger
+
+from .dataset import MOTVideoStream, MOTImageStream
+
+logger = setup_logger(__name__)
+
+
+class StreamTracker(object):
+ def __init__(self, cfg, mode='eval'):
+ self.cfg = cfg
+ assert mode.lower() in ['test', 'eval'], \
+ "mode should be 'test' or 'eval'"
+ self.mode = mode.lower()
+ self.optimizer = None
+
+ # build model
+ self.model = create(cfg.architecture)
+
+ self.status = {}
+ self.start_epoch = 0
+
+ def load_weights_jde(self, weights):
+ load_weight(self.model, weights, self.optimizer)
+
+ def _eval_seq_jde(self, dataloader, save_dir=None, show_image=False, frame_rate=30, draw_threshold=0):
+ if save_dir:
+ if not os.path.exists(save_dir): os.makedirs(save_dir)
+ tracker = self.model.tracker
+ tracker.max_time_lost = int(frame_rate / 30.0 * tracker.track_buffer)
+
+ timer = Timer()
+ results = []
+ frame_id = 0
+ self.status['mode'] = 'track'
+ self.model.eval()
+ for step_id, data in enumerate(dataloader):
+ #print('data', data)
+ self.status['step_id'] = step_id
+ if frame_id % 40 == 0:
+ logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
+
+ # forward
+ timer.tic()
+ pred_dets, pred_embs = self.model(data)
+ online_targets = self.model.tracker.update(pred_dets, pred_embs)
+
+ online_tlwhs, online_ids = [], []
+ online_scores = []
+ for t in online_targets:
+ tlwh = t.tlwh
+ tid = t.track_id
+ tscore = t.score
+ if tscore < draw_threshold: continue
+ vertical = tlwh[2] / tlwh[3] > 1.6
+ if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical:
+ online_tlwhs.append(tlwh)
+ online_ids.append(tid)
+ online_scores.append(tscore)
+ timer.toc()
+
+ # save results
+ results.append((frame_id + 1, online_tlwhs, online_scores, online_ids))
+ self.save_results(data, frame_id, online_ids, online_tlwhs, online_scores, timer.average_time, show_image,
+ save_dir)
+ frame_id += 1
+
+ return results, frame_id, timer.average_time, timer.calls
+
+ def _eval_seq_jde_single_image(self, iterator, save_dir=None, show_image=False, draw_threshold=0):
+ if save_dir:
+ if not os.path.exists(save_dir): os.makedirs(save_dir)
+ tracker = self.model.tracker
+ results = []
+ frame_id = 0
+ self.status['mode'] = 'track'
+ self.model.eval()
+ timer = Timer()
+ while True:
+ try:
+ data = next(iterator)
+ timer.tic()
+ with paddle.no_grad():
+ pred_dets, pred_embs = self.model(data)
+ online_targets = self.model.tracker.update(pred_dets, pred_embs)
+
+ online_tlwhs, online_ids = [], []
+ online_scores = []
+ for t in online_targets:
+ tlwh = t.tlwh
+ tid = t.track_id
+ tscore = t.score
+ if tscore < draw_threshold: continue
+ vertical = tlwh[2] / tlwh[3] > 1.6
+ if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical:
+ online_tlwhs.append(tlwh)
+ online_ids.append(tid)
+ online_scores.append(tscore)
+ timer.toc()
+ # save results
+ results.append((frame_id + 1, online_tlwhs, online_scores, online_ids))
+ self.save_results(data, frame_id, online_ids, online_tlwhs, online_scores, timer.average_time,
+ show_image, save_dir)
+ frame_id += 1
+
+ yield results, frame_id
+
+ except StopIteration as e:
+ return
+
+ def imagestream_predict(self, output_dir, data_type='mot', model_type='JDE', visualization=True,
+ draw_threshold=0.5):
+ if not os.path.exists(output_dir): os.makedirs(output_dir)
+ result_root = os.path.join(output_dir, 'mot_results')
+ if not os.path.exists(result_root): os.makedirs(result_root)
+ assert data_type in ['mot', 'kitti'], \
+ "data_type should be 'mot' or 'kitti'"
+ assert model_type in ['JDE', 'FairMOT'], \
+ "model_type should be 'JDE', or 'FairMOT'"
+ seq = 'inputimages'
+ self.dataset = MOTImageStream(keep_ori_im=True)
+
+ save_dir = os.path.join(output_dir, 'mot_outputs', seq) if visualization else None
+
+ self.dataloader = create('MOTVideoStreamReader')(self.dataset, 0)
+ self.dataloader_iter = iter(self.dataloader)
+ result_filename = os.path.join(result_root, '{}.txt'.format(seq))
+
+ if model_type in ['JDE', 'FairMOT']:
+ generator = self._eval_seq_jde_single_image(
+ self.dataloader_iter, save_dir=save_dir, draw_threshold=draw_threshold)
+ else:
+ raise ValueError(model_type)
+ yield
+ results = []
+ while True:
+ with paddle.no_grad():
+ try:
+ results, nf = next(generator)
+ yield results
+ except StopIteration as e:
+ self.write_mot_results(result_filename, results, data_type)
+ return
+
+ def videostream_predict(self,
+ video_stream,
+ output_dir,
+ data_type='mot',
+ model_type='JDE',
+ visualization=True,
+ draw_threshold=0.5):
+ assert video_stream is not None, \
+ "--video_file or --image_dir should be set."
+
+ if not os.path.exists(output_dir): os.makedirs(output_dir)
+ result_root = os.path.join(output_dir, 'mot_results')
+ if not os.path.exists(result_root): os.makedirs(result_root)
+ assert data_type in ['mot', 'kitti'], \
+ "data_type should be 'mot' or 'kitti'"
+ assert model_type in ['JDE', 'FairMOT'], \
+ "model_type should be 'JDE', or 'FairMOT'"
+ seq = os.path.splitext(os.path.basename(video_stream))[0]
+ self.dataset = MOTVideoStream(video_stream, keep_ori_im=True)
+
+ save_dir = os.path.join(output_dir, 'mot_outputs', seq) if visualization else None
+
+ dataloader = create('MOTVideoStreamReader')(self.dataset, 0)
+ result_filename = os.path.join(result_root, '{}.txt'.format(seq))
+
+ with paddle.no_grad():
+ if model_type in ['JDE', 'FairMOT']:
+ results, nf, ta, tc = self._eval_seq_jde(dataloader, save_dir=save_dir, draw_threshold=draw_threshold)
+ else:
+ raise ValueError(model_type)
+
+ self.write_mot_results(result_filename, results, data_type)
+
+ if visualization:
+ #### Save using ffmpeg
+ #output_video_path = os.path.join(save_dir, '..', '{}_vis.mp4'.format(seq))
+ #cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
+ # save_dir, output_video_path)
+ #os.system(cmd_str)
+ #### Save using opencv
+ output_video_path = os.path.join(save_dir, '..', '{}_vis.avi'.format(seq))
+ imgnames = glob.glob(os.path.join(save_dir, '*.jpg'))
+ if len(imgnames) == 0:
+ logger.info('No output images to save for video')
+ return
+ img = cv2.imread(os.path.join(save_dir, '00000.jpg'))
+ video_writer = cv2.VideoWriter(output_video_path, fourcc=cv2.VideoWriter_fourcc('M','J','P','G'), fps=30, frameSize=[img.shape[1],img.shape[0]])
+ for i in range(len(imgnames)):
+ imgpath = os.path.join(save_dir, '{:05d}.jpg'.format(i))
+ img = cv2.imread(imgpath)
+ video_writer.write(img)
+ video_writer.release()
+ logger.info('Save video in {}'.format(output_video_path))
+
+ def write_mot_results(self, filename, results, data_type='mot'):
+ if data_type in ['mot', 'mcmot', 'lab']:
+ save_format = '{frame},{id},{x1},{y1},{w},{h},{score},-1,-1,-1\n'
+ elif data_type == 'kitti':
+ save_format = '{frame} {id} car 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
+ else:
+ raise ValueError(data_type)
+
+ with open(filename, 'w') as f:
+ for frame_id, tlwhs, tscores, track_ids in results:
+ if data_type == 'kitti':
+ frame_id -= 1
+ for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
+ if track_id < 0:
+ continue
+ x1, y1, w, h = tlwh
+ x2, y2 = x1 + w, y1 + h
+ line = save_format.format(
+ frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=score)
+ f.write(line)
+ logger.info('MOT results save in {}'.format(filename))
+
+ def save_results(self, data, frame_id, online_ids, online_tlwhs, online_scores, average_time, show_image, save_dir):
+ if show_image or save_dir is not None:
+ assert 'ori_image' in data
+ img0 = data['ori_image'].numpy()[0]
+ online_im = mot_vis.plot_tracking(
+ img0, online_tlwhs, online_ids, online_scores, frame_id=frame_id, fps=1. / average_time)
+ if show_image:
+ cv2.imshow('online_im', online_im)
+ if save_dir is not None:
+ cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)