未验证 提交 4caab055 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] add drawing of pptracking and pphuman (#5456)

* add entrance_count draw_center_traj save_mot_txts

* fix pphuman plot_tracking_dict

* fix pphuman plot_tracking_dict

* fix plot tracking dict

* fix pipeline mot plot, add gif

* fix collector append, test=document_fix
上级 49379331
...@@ -39,7 +39,7 @@ PP-Human提供了目标检测、属性识别、行为识别、ReID预训练模 ...@@ -39,7 +39,7 @@ PP-Human提供了目标检测、属性识别、行为识别、ReID预训练模
| 任务 | 适用场景 | 精度 | 预测速度(FPS) | 预测部署模型 | | 任务 | 适用场景 | 精度 | 预测速度(FPS) | 预测部署模型 |
| :---------: |:---------: |:--------------- | :-------: | :------: | | :---------: |:---------: |:--------------- | :-------: | :------: |
| 目标检测 | 图片/视频输入 | - | - | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip) | | 目标检测 | 图片/视频输入 | - | - | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip) |
| 属性识别 | 图片/视频输入 属性识别 | - | - | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/strongbaseline_r50_30e_pa100k.tar) | | 属性识别 | 图片/视频输入 属性识别 | - | - | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/strongbaseline_r50_30e_pa100k.zip) |
| 关键点检测 | 视频输入 行为识别 | - | - | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/dark_hrnet_w32_256x192.zip) | 关键点检测 | 视频输入 行为识别 | - | - | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/dark_hrnet_w32_256x192.zip)
| 行为识别 | 视频输入 行为识别 | - | - | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/STGCN.zip) | | 行为识别 | 视频输入 行为识别 | - | - | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/STGCN.zip) |
| ReID | 视频输入 跨镜跟踪 | - | - | [下载链接]() | | ReID | 视频输入 跨镜跟踪 | - | - | [下载链接]() |
...@@ -117,7 +117,8 @@ python deploy/pphuman/pipeline.py --config deploy/pphuman/config/infer_cfg.yml - ...@@ -117,7 +117,8 @@ python deploy/pphuman/pipeline.py --config deploy/pphuman/config/infer_cfg.yml -
| --enable_mkldnn | Option | CPU预测中是否开启MKLDNN加速,默认为False | | --enable_mkldnn | Option | CPU预测中是否开启MKLDNN加速,默认为False |
| --cpu_threads | Option| 设置cpu线程数,默认为1 | | --cpu_threads | Option| 设置cpu线程数,默认为1 |
| --trt_calib_mode | Option| TensorRT是否使用校准功能,默认为False。使用TensorRT的int8功能时,需设置为True,使用PaddleSlim量化后的模型时需要设置为False | | --trt_calib_mode | Option| TensorRT是否使用校准功能,默认为False。使用TensorRT的int8功能时,需设置为True,使用PaddleSlim量化后的模型时需要设置为False |
| --do_entrance_counting | Option | 是否统计出入口流量,默认为False |
| --draw_center_traj | Option | 是否绘制跟踪轨迹,默认为False |
## 三、方案介绍 ## 三、方案介绍
...@@ -130,13 +131,13 @@ PP-Human整体方案如下图所示 ...@@ -130,13 +131,13 @@ PP-Human整体方案如下图所示
### 1. 目标检测 ### 1. 目标检测
- 采用PP-YOLOE L 作为目标检测模型 - 采用PP-YOLOE L 作为目标检测模型
- 详细文档参考[PP-YOLOE](../../configs/ppyoloe/) - 详细文档参考[PP-YOLOE](../../configs/ppyoloe/)[检测跟踪文档](docs/mot.md)
### 2. 多目标跟踪 ### 2. 多目标跟踪
- 采用SDE方案完成多目标跟踪 - 采用SDE方案完成多目标跟踪
- 检测模型使用PP-YOLOE L - 检测模型使用PP-YOLOE L
- 跟踪模块采用Bytetrack方案 - 跟踪模块采用Bytetrack方案
- 详细文档参考[Bytetrack](configs/mot/bytetrack) - 详细文档参考[Bytetrack](../../configs/mot/bytetrack)[检测跟踪文档](docs/mot.md)
### 3. 跨镜跟踪 ### 3. 跨镜跟踪
- 使用PP-YOLOE + Bytetrack得到单镜头多目标跟踪轨迹 - 使用PP-YOLOE + Bytetrack得到单镜头多目标跟踪轨迹
......
...@@ -5,7 +5,7 @@ visual: True ...@@ -5,7 +5,7 @@ visual: True
warmup_frame: 50 warmup_frame: 50
DET: DET:
model_dir: output_inference/mot_ppyolov3/ model_dir: output_inference/mot_ppyoloe_l_36e_pipeline/
batch_size: 1 batch_size: 1
ATTR: ATTR:
...@@ -13,7 +13,7 @@ ATTR: ...@@ -13,7 +13,7 @@ ATTR:
batch_size: 8 batch_size: 8
MOT: MOT:
model_dir: output_inference/mot_ppyolov3/ model_dir: output_inference/mot_ppyoloe_l_36e_pipeline/
tracker_config: deploy/pphuman/config/tracker_config.yml tracker_config: deploy/pphuman/config/tracker_config.yml
batch_size: 1 batch_size: 1
......
因为 它太大了无法显示 image diff 。你可以改为 查看blob
# PP-Human检测跟踪模块
行人检测与跟踪在智慧社区,工业巡检,交通监控等方向都具有广泛应用,PP-Human中集成了检测跟踪模块,是关键点检测、属性行为识别等任务的基础。我们提供了预训练模型,用户可以直接下载使用。
| 任务 | 算法 | 精度 | 预测速度(ms) |下载链接 |
|:---------------------|:---------:|:------:|:------:| :---------------------------------------------------------------------------------: |
| 行人检测/跟踪 | PP-YOLOE | mAP: 56.3 <br> MOTA: 72.0 | 检测: 28ms <br> 跟踪:33.1ms | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip) |
1. 检测/跟踪模型精度为MOT17,CrowdHuman,HIEVE和部分业务数据融合训练测试得到
2. 预测速度为T4 机器上使用TensorRT FP16时的速度
## 使用方法
1. 从上表链接中下载模型并解压到```./output_inference```路径下
2. 图片输入时,启动命令如下
```python
python deploy/pphuman/pipeline.py --config deploy/pphuman/config/infer_cfg.yml \
--image_file=test_image.jpg \
--device=gpu
```
3. 视频输入时,启动命令如下
```python
python deploy/pphuman/pipeline.py --config deploy/pphuman/config/infer_cfg.yml \
--video_file=test_video.mp4 \
--device=gpu
```
4. 若修改模型路径,有以下两种方式:
- ```./deploy/pphuman/config/infer_cfg.yml```下可以配置不同模型路径,检测和跟踪模型分别对应`DET`和`MOT`字段,修改对应字段下的路径为实际期望的路径即可。
- 命令行中增加`--model_dir`修改模型路径:
```python
python deploy/pphuman/pipeline.py --config deploy/pphuman/config/infer_cfg.yml \
--video_file=test_video.mp4 \
--device=gpu \
--model_dir det=ppyoloe/
--do_entrance_counting \
--draw_center_traj
```
**注意:**
- `--do_entrance_counting`表示是否统计出入口流量,不设置即默认为False,`--draw_center_traj`表示是否绘制跟踪轨迹,不设置即默认为False。注意绘制跟踪轨迹的测试视频最好是静止摄像头拍摄的。
测试效果如下:
<div width="1000" align="center">
<img src="./images/mot.gif"/>
</div>
数据来源及版权归属:天覆科技,感谢提供并开源实际场景数据,仅限学术研究使用
## 方案说明
1. 目标检测/多目标跟踪获取图片/视频输入中的行人检测框,模型方案为PP-YOLOE,详细文档参考[PP-YOLOE](../../../configs/ppyoloe)
2. 多目标跟踪模型方案基于[ByteTrack](https://arxiv.org/pdf/2110.06864.pdf),采用PP-YOLOE替换原文的YOLOX作为检测器,采用BYTETracker作为跟踪器。
## 参考文献
```
@article{zhang2021bytetrack,
title={ByteTrack: Multi-Object Tracking by Associating Every Detection Box},
author={Zhang, Yifu and Sun, Peize and Jiang, Yi and Yu, Dongdong and Yuan, Zehuan and Luo, Ping and Liu, Wenyu and Wang, Xinggang},
journal={arXiv preprint arXiv:2110.06864},
year={2021}
}
```
...@@ -108,6 +108,21 @@ def argsparser(): ...@@ -108,6 +108,21 @@ def argsparser():
default=False, default=False,
help="If the model is produced by TRT offline quantitative " help="If the model is produced by TRT offline quantitative "
"calibration, trt_calib_mode need to set True.") "calibration, trt_calib_mode need to set True.")
parser.add_argument(
"--do_entrance_counting",
action='store_true',
help="Whether counting the numbers of identifiers entering "
"or getting out from the entrance. Note that only support one-class"
"counting, multi-class counting is coming soon.")
parser.add_argument(
"--secs_interval",
type=int,
default=2,
help="The seconds interval to count after tracking")
parser.add_argument(
"--draw_center_traj",
action='store_true',
help="Whether drawing the trajectory of center")
return parser return parser
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
import yaml import yaml
import glob import glob
from collections import defaultdict
import cv2 import cv2
import numpy as np import numpy as np
...@@ -44,7 +45,8 @@ from python.preprocess import decode_image ...@@ -44,7 +45,8 @@ from python.preprocess import decode_image
from python.visualize import visualize_box_mask, visualize_attr, visualize_pose, visualize_action from python.visualize import visualize_box_mask, visualize_attr, visualize_pose, visualize_action
from pptracking.python.mot_sde_infer import SDE_Detector from pptracking.python.mot_sde_infer import SDE_Detector
from pptracking.python.mot.visualize import plot_tracking from pptracking.python.mot.visualize import plot_tracking_dict
from pptracking.python.mot.utils import flow_statistic
class Pipeline(object): class Pipeline(object):
...@@ -72,6 +74,11 @@ class Pipeline(object): ...@@ -72,6 +74,11 @@ class Pipeline(object):
cpu_threads (int): cpu threads, default as 1 cpu_threads (int): cpu threads, default as 1
enable_mkldnn (bool): whether to open MKLDNN, default as False enable_mkldnn (bool): whether to open MKLDNN, default as False
output_dir (string): The path of output, default as 'output' output_dir (string): The path of output, default as 'output'
draw_center_traj (bool): Whether drawing the trajectory of center, default as False
secs_interval (int): The seconds interval to count after tracking, default as 10
do_entrance_counting(bool): Whether counting the numbers of identifiers entering
or getting out from the entrance, default as False,only support single class
counting in MOT.
""" """
def __init__(self, def __init__(self,
...@@ -91,7 +98,10 @@ class Pipeline(object): ...@@ -91,7 +98,10 @@ class Pipeline(object):
trt_calib_mode=False, trt_calib_mode=False,
cpu_threads=1, cpu_threads=1,
enable_mkldnn=False, enable_mkldnn=False,
output_dir='output'): output_dir='output',
draw_center_traj=False,
secs_interval=10,
do_entrance_counting=False):
self.multi_camera = False self.multi_camera = False
self.is_video = False self.is_video = False
self.output_dir = output_dir self.output_dir = output_dir
...@@ -129,10 +139,18 @@ class Pipeline(object): ...@@ -129,10 +139,18 @@ class Pipeline(object):
trt_calib_mode=trt_calib_mode, trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads, cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn, enable_mkldnn=enable_mkldnn,
output_dir=output_dir) output_dir=output_dir,
draw_center_traj=draw_center_traj,
secs_interval=secs_interval,
do_entrance_counting=do_entrance_counting)
if self.is_video: if self.is_video:
self.predictor.set_file_name(video_file) self.predictor.set_file_name(video_file)
self.output_dir = output_dir
self.draw_center_traj = draw_center_traj
self.secs_interval = secs_interval
self.do_entrance_counting = do_entrance_counting
def _parse_input(self, image_file, image_dir, video_file, video_dir, def _parse_input(self, image_file, image_dir, video_file, video_dir,
camera_id): camera_id):
...@@ -144,6 +162,7 @@ class Pipeline(object): ...@@ -144,6 +162,7 @@ class Pipeline(object):
self.multi_camera = False self.multi_camera = False
elif video_file is not None: elif video_file is not None:
assert os.path.exists(video_file), "video_file not exists."
self.multi_camera = False self.multi_camera = False
input = video_file input = video_file
self.is_video = True self.is_video = True
...@@ -222,6 +241,11 @@ class PipePredictor(object): ...@@ -222,6 +241,11 @@ class PipePredictor(object):
cpu_threads (int): cpu threads, default as 1 cpu_threads (int): cpu threads, default as 1
enable_mkldnn (bool): whether to open MKLDNN, default as False enable_mkldnn (bool): whether to open MKLDNN, default as False
output_dir (string): The path of output, default as 'output' output_dir (string): The path of output, default as 'output'
draw_center_traj (bool): Whether drawing the trajectory of center, default as False
secs_interval (int): The seconds interval to count after tracking, default as 10
do_entrance_counting(bool): Whether counting the numbers of identifiers entering
or getting out from the entrance, default as False,only support single class
counting in MOT.
""" """
def __init__(self, def __init__(self,
...@@ -238,7 +262,10 @@ class PipePredictor(object): ...@@ -238,7 +262,10 @@ class PipePredictor(object):
trt_calib_mode=False, trt_calib_mode=False,
cpu_threads=1, cpu_threads=1,
enable_mkldnn=False, enable_mkldnn=False,
output_dir='output'): output_dir='output',
draw_center_traj=False,
secs_interval=10,
do_entrance_counting=False):
if enable_attr and not cfg.get('ATTR', False): if enable_attr and not cfg.get('ATTR', False):
ValueError( ValueError(
...@@ -268,6 +295,9 @@ class PipePredictor(object): ...@@ -268,6 +295,9 @@ class PipePredictor(object):
self.multi_camera = multi_camera self.multi_camera = multi_camera
self.cfg = cfg self.cfg = cfg
self.output_dir = output_dir self.output_dir = output_dir
self.draw_center_traj = draw_center_traj
self.secs_interval = secs_interval
self.do_entrance_counting = do_entrance_counting
self.warmup_frame = self.cfg['warmup_frame'] self.warmup_frame = self.cfg['warmup_frame']
self.pipeline_res = Result() self.pipeline_res = Result()
...@@ -298,9 +328,20 @@ class PipePredictor(object): ...@@ -298,9 +328,20 @@ class PipePredictor(object):
tracker_config = mot_cfg['tracker_config'] tracker_config = mot_cfg['tracker_config']
batch_size = mot_cfg['batch_size'] batch_size = mot_cfg['batch_size']
self.mot_predictor = SDE_Detector( self.mot_predictor = SDE_Detector(
model_dir, tracker_config, device, run_mode, batch_size, model_dir,
trt_min_shape, trt_max_shape, trt_opt_shape, trt_calib_mode, tracker_config,
cpu_threads, enable_mkldnn) device,
run_mode,
batch_size,
trt_min_shape,
trt_max_shape,
trt_opt_shape,
trt_calib_mode,
cpu_threads,
enable_mkldnn,
draw_center_traj=draw_center_traj,
secs_interval=secs_interval,
do_entrance_counting=do_entrance_counting)
if self.with_attr: if self.with_attr:
attr_cfg = self.cfg['ATTR'] attr_cfg = self.cfg['ATTR']
model_dir = attr_cfg['model_dir'] model_dir = attr_cfg['model_dir']
...@@ -431,6 +472,7 @@ class PipePredictor(object): ...@@ -431,6 +472,7 @@ class PipePredictor(object):
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(capture.get(cv2.CAP_PROP_FPS)) fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("video fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(self.output_dir): if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir) os.makedirs(self.output_dir)
...@@ -438,6 +480,19 @@ class PipePredictor(object): ...@@ -438,6 +480,19 @@ class PipePredictor(object):
fourcc = cv2.VideoWriter_fourcc(* 'mp4v') fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0 frame_id = 0
entrance, records, center_traj = None, None, None
if self.draw_center_traj:
center_traj = [{}]
id_set = set()
interval_id_set = set()
in_id_list = list()
out_id_list = list()
prev_center = dict()
records = list()
entrance = [0, height / 2., width, height / 2.]
video_fps = fps
while (1): while (1):
if frame_id % 10 == 0: if frame_id % 10 == 0:
print('frame id: ', frame_id) print('frame id: ', frame_id)
...@@ -457,6 +512,16 @@ class PipePredictor(object): ...@@ -457,6 +512,16 @@ class PipePredictor(object):
# mot output format: id, class, score, xmin, ymin, xmax, ymax # mot output format: id, class, score, xmin, ymin, xmax, ymax
mot_res = parse_mot_res(res) mot_res = parse_mot_res(res)
# flow_statistic only support single class MOT
boxes, scores, ids = res[0] # batch size = 1 in MOT
mot_result = (frame_id + 1, boxes[0], scores[0],
ids[0]) # single class
statistic = flow_statistic(
mot_result, self.secs_interval, self.do_entrance_counting,
video_fps, entrance, id_set, interval_id_set, in_id_list,
out_id_list, prev_center, records)
records = statistic['records']
# nothing detected # nothing detected
if len(mot_res['boxes']) == 0: if len(mot_res['boxes']) == 0:
frame_id += 1 frame_id += 1
...@@ -549,13 +614,21 @@ class PipePredictor(object): ...@@ -549,13 +614,21 @@ class PipePredictor(object):
if self.cfg['visual']: if self.cfg['visual']:
_, _, fps = self.pipe_timer.get_total_time() _, _, fps = self.pipe_timer.get_total_time()
im = self.visualize_video(frame, self.pipeline_res, frame_id, im = self.visualize_video(frame, self.pipeline_res, frame_id,
fps) # visualize fps, entrance, records,
center_traj) # visualize
writer.write(im) writer.write(im)
writer.release() writer.release()
print('save result to {}'.format(out_path)) print('save result to {}'.format(out_path))
def visualize_video(self, image, result, frame_id, fps): def visualize_video(self,
image,
result,
frame_id,
fps,
entrance=None,
records=None,
center_traj=None):
mot_res = copy.deepcopy(result.get('mot')) mot_res = copy.deepcopy(result.get('mot'))
if mot_res is not None: if mot_res is not None:
ids = mot_res['boxes'][:, 0] ids = mot_res['boxes'][:, 0]
...@@ -567,8 +640,28 @@ class PipePredictor(object): ...@@ -567,8 +640,28 @@ class PipePredictor(object):
boxes = np.zeros([0, 4]) boxes = np.zeros([0, 4])
ids = np.zeros([0]) ids = np.zeros([0])
scores = np.zeros([0]) scores = np.zeros([0])
image = plot_tracking(
image, boxes, ids, scores, frame_id=frame_id, fps=fps) # single class, still need to be defaultdict type for ploting
num_classes = 1
online_tlwhs = defaultdict(list)
online_scores = defaultdict(list)
online_ids = defaultdict(list)
online_tlwhs[0] = boxes
online_scores[0] = scores
online_ids[0] = ids
image = plot_tracking_dict(
image,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=fps,
do_entrance_counting=self.do_entrance_counting,
entrance=entrance,
records=records,
center_traj=center_traj)
attr_res = result.get('attr') attr_res = result.get('attr')
if attr_res is not None: if attr_res is not None:
...@@ -630,7 +723,8 @@ def main(): ...@@ -630,7 +723,8 @@ def main():
FLAGS.video_dir, FLAGS.camera_id, FLAGS.enable_attr, FLAGS.video_dir, FLAGS.camera_id, FLAGS.enable_attr,
FLAGS.enable_action, FLAGS.device, FLAGS.run_mode, FLAGS.trt_min_shape, FLAGS.enable_action, FLAGS.device, FLAGS.run_mode, FLAGS.trt_min_shape,
FLAGS.trt_max_shape, FLAGS.trt_opt_shape, FLAGS.trt_calib_mode, FLAGS.trt_max_shape, FLAGS.trt_opt_shape, FLAGS.trt_calib_mode,
FLAGS.cpu_threads, FLAGS.enable_mkldnn, FLAGS.output_dir) FLAGS.cpu_threads, FLAGS.enable_mkldnn, FLAGS.output_dir,
FLAGS.draw_center_traj, FLAGS.secs_interval, FLAGS.do_entrance_counting)
pipeline.run() pipeline.run()
......
...@@ -35,10 +35,21 @@ wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/mot17_demo.mp4 ...@@ -35,10 +35,21 @@ wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/mot17_demo.mp4
# Python预测视频 # Python预测视频
python deploy/pptracking/python/mot_jde_infer.py --model_dir=output_inference/fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video_file=mot17_demo.mp4 --device=GPU --threshold=0.5 --save_mot_txts --save_images python deploy/pptracking/python/mot_jde_infer.py --model_dir=output_inference/fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video_file=mot17_demo.mp4 --device=GPU --threshold=0.5 --save_mot_txts --save_images
``` ```
### 1.3 用导出的模型基于Python去预测,以及进行流量计数、出入口统计和绘制跟踪轨迹等
```bash
# 下载出入口统计demo视频:
wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/entrance_count_demo.mp4
# Python预测视频
python deploy/pptracking/python/mot_jde_infer.py --model_dir=output_inference/fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video_file=entrance_count_demo.mp4 --device=GPU --do_entrance_counting --draw_center_traj
```
**注意:** **注意:**
- 跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`表示保存跟踪结果的txt文件,或`--save_images`表示保存跟踪结果可视化图片。 - 跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`表示保存跟踪结果的txt文件,或`--save_images`表示保存跟踪结果可视化图片。
- 跟踪结果txt文件每行信息是`frame,id,x1,y1,w,h,score,-1,-1,-1` - 跟踪结果txt文件每行信息是`frame,id,x1,y1,w,h,score,-1,-1,-1`
- `--threshold`表示结果可视化的置信度阈值,默认为0.5,低于该阈值的结果会被过滤掉,为了可视化效果更佳,可根据实际情况自行修改。 - `--threshold`表示结果可视化的置信度阈值,默认为0.5,低于该阈值的结果会被过滤掉,为了可视化效果更佳,可根据实际情况自行修改。
- `--do_entrance_counting`表示是否统计出入口流量,默认为False,`--draw_center_traj`表示是否绘制跟踪轨迹,默认为False。注意绘制跟踪轨迹的测试视频最好是静止摄像头拍摄的。
- 对于多类别或车辆的FairMOT模型的导出和Python预测只需更改相应的config和模型权重即可。如: - 对于多类别或车辆的FairMOT模型的导出和Python预测只需更改相应的config和模型权重即可。如:
```bash ```bash
job_name=mcfairmot_hrnetv2_w18_dlafpn_30e_576x320_visdrone job_name=mcfairmot_hrnetv2_w18_dlafpn_30e_576x320_visdrone
......
...@@ -27,7 +27,7 @@ from .utils import parse_pt_gt, parse_pt, compare_dataframes_mtmc ...@@ -27,7 +27,7 @@ from .utils import parse_pt_gt, parse_pt, compare_dataframes_mtmc
from .utils import get_labels, getData, gen_new_mot from .utils import get_labels, getData, gen_new_mot
from .camera_utils import get_labels_with_camera from .camera_utils import get_labels_with_camera
from .zone import Zone from .zone import Zone
from ..utils import plot_tracking from ..visualize import plot_tracking
__all__ = [ __all__ = [
'trajectory_fusion', 'trajectory_fusion',
...@@ -68,8 +68,8 @@ def trajectory_fusion(mot_feature, cid, cid_bias, use_zone=False, zone_path=''): ...@@ -68,8 +68,8 @@ def trajectory_fusion(mot_feature, cid, cid_bias, use_zone=False, zone_path=''):
zone_list = [tracklet[f]['zone'] for f in frame_list] zone_list = [tracklet[f]['zone'] for f in frame_list]
feature_list = [ feature_list = [
tracklet[f]['feat'] for f in frame_list tracklet[f]['feat'] for f in frame_list
if (tracklet[f]['bbox'][3] - tracklet[f]['bbox'][1] if (tracklet[f]['bbox'][3] - tracklet[f]['bbox'][1]) *
) * (tracklet[f]['bbox'][2] - tracklet[f]['bbox'][0]) > 2000 (tracklet[f]['bbox'][2] - tracklet[f]['bbox'][0]) > 2000
] ]
if len(feature_list) < 2: if len(feature_list) < 2:
feature_list = [tracklet[f]['feat'] for f in frame_list] feature_list = [tracklet[f]['feat'] for f in frame_list]
......
...@@ -20,8 +20,7 @@ import collections ...@@ -20,8 +20,7 @@ import collections
__all__ = [ __all__ = [
'MOTTimer', 'Detection', 'write_mot_results', 'load_det_results', 'MOTTimer', 'Detection', 'write_mot_results', 'load_det_results',
'preprocess_reid', 'get_crops', 'clip_box', 'scale_coords', 'flow_statistic', 'preprocess_reid', 'get_crops', 'clip_box', 'scale_coords', 'flow_statistic'
'plot_tracking'
] ]
...@@ -197,10 +196,7 @@ def preprocess_reid(imgs, ...@@ -197,10 +196,7 @@ def preprocess_reid(imgs,
std=[0.229, 0.224, 0.225]): std=[0.229, 0.224, 0.225]):
im_batch = [] im_batch = []
for img in imgs: for img in imgs:
try:
img = cv2.resize(img, (w, h)) img = cv2.resize(img, (w, h))
except:
embed()
img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255 img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
img_mean = np.array(mean).reshape((3, 1, 1)) img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1)) img_std = np.array(std).reshape((3, 1, 1))
...@@ -288,77 +284,3 @@ def flow_statistic(result, ...@@ -288,77 +284,3 @@ def flow_statistic(result,
"prev_center": prev_center, "prev_center": prev_center,
"records": records "records": records
} }
def get_color(idx):
idx = idx * 3
color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
return color
def plot_tracking(image,
tlwhs,
obj_ids,
scores=None,
frame_id=0,
fps=0.,
ids2names=[],
do_entrance_counting=False,
entrance=None):
im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2]
text_scale = max(1, image.shape[1] / 1600.)
text_thickness = 2
line_thickness = max(1, int(image.shape[1] / 500.))
if fps > 0:
_line = 'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs))
else:
_line = 'frame: %d num: %d' % (frame_id, len(tlwhs))
cv2.putText(
im,
_line,
(0, int(15 * text_scale)),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=2)
for i, tlwh in enumerate(tlwhs):
x1, y1, w, h = tlwh
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
obj_id = int(obj_ids[i])
id_text = '{}'.format(int(obj_id))
if ids2names != []:
assert len(
ids2names) == 1, "plot_tracking only supports single classes."
id_text = '{}_'.format(ids2names[0]) + id_text
_line_thickness = 1 if obj_id <= 0 else line_thickness
color = get_color(abs(obj_id))
cv2.rectangle(
im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
cv2.putText(
im,
id_text, (intbox[0], intbox[1] - 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=text_thickness)
if scores is not None:
text = '{:.2f}'.format(float(scores[i]))
cv2.putText(
im,
text, (intbox[0], intbox[1] + 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=text_thickness)
if do_entrance_counting:
entrance_line = tuple(map(int, entrance))
cv2.rectangle(
im,
entrance_line[0:2],
entrance_line[2:4],
color=(0, 255, 255),
thickness=line_thickness)
return im
...@@ -31,7 +31,7 @@ parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) ...@@ -31,7 +31,7 @@ parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path) sys.path.insert(0, parent_path)
from mot import JDETracker from mot import JDETracker
from mot.utils import MOTTimer, write_mot_results from mot.utils import MOTTimer, write_mot_results, flow_statistic
from mot.visualize import plot_tracking, plot_tracking_dict from mot.visualize import plot_tracking, plot_tracking_dict
# Global dictionary # Global dictionary
...@@ -55,9 +55,19 @@ class JDE_Detector(Detector): ...@@ -55,9 +55,19 @@ class JDE_Detector(Detector):
calibration, trt_calib_mode need to set True calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN enable_mkldnn (bool): whether to open MKLDNN
output_dir (string): The path of output, default as 'output'
threshold (float): Score threshold of the detected bbox, default as 0.5
save_images (bool): Whether to save visualization image results, default as False
save_mot_txts (bool): Whether to save tracking results (txt), default as False
draw_center_traj (bool): Whether drawing the trajectory of center, default as False
secs_interval (int): The seconds interval to count after tracking, default as 10
do_entrance_counting(bool): Whether counting the numbers of identifiers entering
or getting out from the entrance, default as False,only support single class
counting in MOT.
""" """
def __init__(self, def __init__(
self,
model_dir, model_dir,
tracker_config=None, tracker_config=None,
device='CPU', device='CPU',
...@@ -70,7 +80,12 @@ class JDE_Detector(Detector): ...@@ -70,7 +80,12 @@ class JDE_Detector(Detector):
cpu_threads=1, cpu_threads=1,
enable_mkldnn=False, enable_mkldnn=False,
output_dir='output', output_dir='output',
threshold=0.5): threshold=0.5,
save_images=False,
save_mot_txts=False,
draw_center_traj=False,
secs_interval=10,
do_entrance_counting=False, ):
super(JDE_Detector, self).__init__( super(JDE_Detector, self).__init__(
model_dir=model_dir, model_dir=model_dir,
device=device, device=device,
...@@ -84,6 +99,12 @@ class JDE_Detector(Detector): ...@@ -84,6 +99,12 @@ class JDE_Detector(Detector):
enable_mkldnn=enable_mkldnn, enable_mkldnn=enable_mkldnn,
output_dir=output_dir, output_dir=output_dir,
threshold=threshold, ) threshold=threshold, )
self.save_images = save_images
self.save_mot_txts = save_mot_txts
self.draw_center_traj = draw_center_traj
self.secs_interval = secs_interval
self.do_entrance_counting = do_entrance_counting
assert batch_size == 1, "MOT model only supports batch_size=1." assert batch_size == 1, "MOT model only supports batch_size=1."
self.det_times = Timer(with_tracker=True) self.det_times = Timer(with_tracker=True)
self.num_classes = len(self.pred_config.labels) self.num_classes = len(self.pred_config.labels)
...@@ -115,7 +136,7 @@ class JDE_Detector(Detector): ...@@ -115,7 +136,7 @@ class JDE_Detector(Detector):
return result return result
def tracking(self, det_results): def tracking(self, det_results):
pred_dets = det_results['pred_dets'] pred_dets = det_results['pred_dets'] # cls_id, score, x0, y0, x1, y1
pred_embs = det_results['pred_embs'] pred_embs = det_results['pred_embs']
online_targets_dict = self.tracker.update(pred_dets, pred_embs) online_targets_dict = self.tracker.update(pred_dets, pred_embs)
...@@ -164,7 +185,8 @@ class JDE_Detector(Detector): ...@@ -164,7 +185,8 @@ class JDE_Detector(Detector):
image_list, image_list,
run_benchmark=False, run_benchmark=False,
repeats=1, repeats=1,
visual=True): visual=True,
seq_name=None):
mot_results = [] mot_results = []
num_classes = self.num_classes num_classes = self.num_classes
image_list.sort() image_list.sort()
...@@ -225,7 +247,7 @@ class JDE_Detector(Detector): ...@@ -225,7 +247,7 @@ class JDE_Detector(Detector):
self.det_times.img_num += 1 self.det_times.img_num += 1
if visual: if visual:
if frame_id % 10 == 0: if len(image_list) > 1 and frame_id % 10 == 0:
print('Tracking frame {}'.format(frame_id)) print('Tracking frame {}'.format(frame_id))
frame, _ = decode_image(img_file, {}) frame, _ = decode_image(img_file, {})
...@@ -237,6 +259,7 @@ class JDE_Detector(Detector): ...@@ -237,6 +259,7 @@ class JDE_Detector(Detector):
online_scores, online_scores,
frame_id=frame_id, frame_id=frame_id,
ids2names=ids2names) ids2names=ids2names)
if seq_name is None:
seq_name = image_list[0].split('/')[-2] seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name) save_dir = os.path.join(self.output_dir, seq_name)
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
...@@ -264,7 +287,8 @@ class JDE_Detector(Detector): ...@@ -264,7 +287,8 @@ class JDE_Detector(Detector):
if not os.path.exists(self.output_dir): if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir) os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name) out_path = os.path.join(self.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(* 'mp4v') video_format = 'mp4v'
fourcc = cv2.VideoWriter_fourcc(*video_format)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 1 frame_id = 1
...@@ -273,6 +297,23 @@ class JDE_Detector(Detector): ...@@ -273,6 +297,23 @@ class JDE_Detector(Detector):
num_classes = self.num_classes num_classes = self.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot' data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = self.pred_config.labels ids2names = self.pred_config.labels
center_traj = None
entrance = None
records = None
if self.draw_center_traj:
center_traj = [{} for i in range(num_classes)]
if num_classes == 1:
id_set = set()
interval_id_set = set()
in_id_list = list()
out_id_list = list()
prev_center = dict()
records = list()
entrance = [0, height / 2., width, height / 2.]
video_fps = fps
while (1): while (1):
ret, frame = capture.read() ret, frame = capture.read()
if not ret: if not ret:
...@@ -282,7 +323,9 @@ class JDE_Detector(Detector): ...@@ -282,7 +323,9 @@ class JDE_Detector(Detector):
frame_id += 1 frame_id += 1
timer.tic() timer.tic()
mot_results = self.predict_image([frame], visual=False) seq_name = video_out_name.split('.')[0]
mot_results = self.predict_image(
[frame], visual=False, seq_name=seq_name)
timer.toc() timer.toc()
online_tlwhs, online_scores, online_ids = mot_results[0] online_tlwhs, online_scores, online_ids = mot_results[0]
...@@ -291,6 +334,16 @@ class JDE_Detector(Detector): ...@@ -291,6 +334,16 @@ class JDE_Detector(Detector):
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id], (frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
online_ids[cls_id])) online_ids[cls_id]))
# NOTE: just implement flow statistic for single class
if num_classes == 1:
result = (frame_id + 1, online_tlwhs[0], online_scores[0],
online_ids[0])
statistic = flow_statistic(
result, self.secs_interval, self.do_entrance_counting,
video_fps, entrance, id_set, interval_id_set, in_id_list,
out_id_list, prev_center, records, data_type, num_classes)
records = statistic['records']
fps = 1. / timer.duration fps = 1. / timer.duration
im = plot_tracking_dict( im = plot_tracking_dict(
frame, frame,
...@@ -300,27 +353,57 @@ class JDE_Detector(Detector): ...@@ -300,27 +353,57 @@ class JDE_Detector(Detector):
online_scores, online_scores,
frame_id=frame_id, frame_id=frame_id,
fps=fps, fps=fps,
ids2names=ids2names) ids2names=ids2names,
do_entrance_counting=self.do_entrance_counting,
entrance=entrance,
records=records,
center_traj=center_traj)
writer.write(im) writer.write(im)
if camera_id != -1: if camera_id != -1:
cv2.imshow('Mask Detection', im) cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
break break
if self.save_mot_txts:
result_filename = os.path.join(
self.output_dir, video_out_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results, data_type, num_classes)
if num_classes == 1:
result_filename = os.path.join(
self.output_dir,
video_out_name.split('.')[-2] + '_flow_statistic.txt')
f = open(result_filename, 'w')
for line in records:
f.write(line)
print('Flow statistic save in {}'.format(result_filename))
f.close()
writer.release() writer.release()
def main(): def main():
detector = JDE_Detector( detector = JDE_Detector(
FLAGS.model_dir, FLAGS.model_dir,
tracker_config=None,
device=FLAGS.device, device=FLAGS.device,
run_mode=FLAGS.run_mode, run_mode=FLAGS.run_mode,
batch_size=1,
trt_min_shape=FLAGS.trt_min_shape, trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape, trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape, trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode, trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads, cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn) enable_mkldnn=FLAGS.enable_mkldnn,
output_dir=FLAGS.output_dir,
threshold=FLAGS.threshold,
save_images=FLAGS.save_images,
save_mot_txts=FLAGS.save_mot_txts,
draw_center_traj=FLAGS.draw_center_traj,
secs_interval=FLAGS.secs_interval,
do_entrance_counting=FLAGS.do_entrance_counting, )
# predict from video file or camera video stream # predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1: if FLAGS.video_file is not None or FLAGS.camera_id != -1:
......
...@@ -33,7 +33,7 @@ sys.path.insert(0, parent_path) ...@@ -33,7 +33,7 @@ sys.path.insert(0, parent_path)
from det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor from det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
from mot_utils import argsparser, Timer, get_current_memory_mb, video2frames, _is_valid_video from mot_utils import argsparser, Timer, get_current_memory_mb, video2frames, _is_valid_video
from mot.tracker import JDETracker, DeepSORTTracker from mot.tracker import JDETracker, DeepSORTTracker
from mot.utils import MOTTimer, write_mot_results, flow_statistic, get_crops, clip_box from mot.utils import MOTTimer, write_mot_results, get_crops, clip_box, flow_statistic
from mot.visualize import plot_tracking, plot_tracking_dict from mot.visualize import plot_tracking, plot_tracking_dict
from mot.mtmct.utils import parse_bias from mot.mtmct.utils import parse_bias
...@@ -56,6 +56,15 @@ class SDE_Detector(Detector): ...@@ -56,6 +56,15 @@ class SDE_Detector(Detector):
calibration, trt_calib_mode need to set True calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN enable_mkldnn (bool): whether to open MKLDNN
output_dir (string): The path of output, default as 'output'
threshold (float): Score threshold of the detected bbox, default as 0.5
save_images (bool): Whether to save visualization image results, default as False
save_mot_txts (bool): Whether to save tracking results (txt), default as False
draw_center_traj (bool): Whether drawing the trajectory of center, default as False
secs_interval (int): The seconds interval to count after tracking, default as 10
do_entrance_counting(bool): Whether counting the numbers of identifiers entering
or getting out from the entrance, default as False,only support single class
counting in MOT.
reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT
mtmct_dir (str): MTMCT dir, default None, set for doing MTMCT mtmct_dir (str): MTMCT dir, default None, set for doing MTMCT
""" """
...@@ -74,6 +83,11 @@ class SDE_Detector(Detector): ...@@ -74,6 +83,11 @@ class SDE_Detector(Detector):
enable_mkldnn=False, enable_mkldnn=False,
output_dir='output', output_dir='output',
threshold=0.5, threshold=0.5,
save_images=False,
save_mot_txts=False,
draw_center_traj=False,
secs_interval=10,
do_entrance_counting=False,
reid_model_dir=None, reid_model_dir=None,
mtmct_dir=None): mtmct_dir=None):
super(SDE_Detector, self).__init__( super(SDE_Detector, self).__init__(
...@@ -89,6 +103,12 @@ class SDE_Detector(Detector): ...@@ -89,6 +103,12 @@ class SDE_Detector(Detector):
enable_mkldnn=enable_mkldnn, enable_mkldnn=enable_mkldnn,
output_dir=output_dir, output_dir=output_dir,
threshold=threshold, ) threshold=threshold, )
self.save_images = save_images
self.save_mot_txts = save_mot_txts
self.draw_center_traj = draw_center_traj
self.secs_interval = secs_interval
self.do_entrance_counting = do_entrance_counting
assert batch_size == 1, "MOT model only supports batch_size=1." assert batch_size == 1, "MOT model only supports batch_size=1."
self.det_times = Timer(with_tracker=True) self.det_times = Timer(with_tracker=True)
self.num_classes = len(self.pred_config.labels) self.num_classes = len(self.pred_config.labels)
...@@ -309,6 +329,7 @@ class SDE_Detector(Detector): ...@@ -309,6 +329,7 @@ class SDE_Detector(Detector):
feat_data['feat'] = _feat feat_data['feat'] = _feat
tracking_outs['feat_data'].update({_imgname: feat_data}) tracking_outs['feat_data'].update({_imgname: feat_data})
return tracking_outs return tracking_outs
else: else:
tracking_outs = { tracking_outs = {
'online_tlwhs': online_tlwhs, 'online_tlwhs': online_tlwhs,
...@@ -409,7 +430,7 @@ class SDE_Detector(Detector): ...@@ -409,7 +430,7 @@ class SDE_Detector(Detector):
mot_results.append([online_tlwhs, online_scores, online_ids]) mot_results.append([online_tlwhs, online_scores, online_ids])
if visual: if visual:
if frame_id % 10 == 0: if len(image_list) > 1 and frame_id % 10 == 0:
print('Tracking frame {}'.format(frame_id)) print('Tracking frame {}'.format(frame_id))
frame, _ = decode_image(img_file, {}) frame, _ = decode_image(img_file, {})
if isinstance(online_tlwhs, defaultdict): if isinstance(online_tlwhs, defaultdict):
...@@ -456,13 +477,32 @@ class SDE_Detector(Detector): ...@@ -456,13 +477,32 @@ class SDE_Detector(Detector):
if not os.path.exists(self.output_dir): if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir) os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name) out_path = os.path.join(self.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(* 'mp4v') video_format = 'mp4v'
fourcc = cv2.VideoWriter_fourcc(*video_format)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 1 frame_id = 1
timer = MOTTimer() timer = MOTTimer()
results = defaultdict(list) # support single class and multi classes results = defaultdict(list)
num_classes = self.num_classes num_classes = self.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = self.pred_config.labels
center_traj = None
entrance = None
records = None
if self.draw_center_traj:
center_traj = [{} for i in range(num_classes)]
if num_classes == 1:
id_set = set()
interval_id_set = set()
in_id_list = list()
out_id_list = list()
prev_center = dict()
records = list()
entrance = [0, height / 2., width, height / 2.]
video_fps = fps
while (1): while (1):
ret, frame = capture.read() ret, frame = capture.read()
if not ret: if not ret:
...@@ -477,10 +517,21 @@ class SDE_Detector(Detector): ...@@ -477,10 +517,21 @@ class SDE_Detector(Detector):
[frame], visual=False, seq_name=seq_name) [frame], visual=False, seq_name=seq_name)
timer.toc() timer.toc()
online_tlwhs, online_scores, online_ids = mot_results[ # bs=1 in MOT model
0] # bs=1 in MOT model online_tlwhs, online_scores, online_ids = mot_results[0]
# NOTE: just implement flow statistic for one class
if num_classes == 1:
result = (frame_id + 1, online_tlwhs[0], online_scores[0],
online_ids[0])
statistic = flow_statistic(
result, self.secs_interval, self.do_entrance_counting,
video_fps, entrance, id_set, interval_id_set, in_id_list,
out_id_list, prev_center, records, data_type, num_classes)
records = statistic['records']
fps = 1. / timer.duration fps = 1. / timer.duration
if num_classes == 1 and self.use_reid: if self.use_deepsort_tracker:
# use DeepSORTTracker, only support singe class # use DeepSORTTracker, only support singe class
results[0].append( results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids)) (frame_id + 1, online_tlwhs, online_scores, online_ids))
...@@ -490,7 +541,9 @@ class SDE_Detector(Detector): ...@@ -490,7 +541,9 @@ class SDE_Detector(Detector):
online_ids, online_ids,
online_scores, online_scores,
frame_id=frame_id, frame_id=frame_id,
fps=fps) fps=fps,
do_entrance_counting=self.do_entrance_counting,
entrance=entrance)
else: else:
# use ByteTracker, support multiple class # use ByteTracker, support multiple class
for cls_id in range(num_classes): for cls_id in range(num_classes):
...@@ -505,13 +558,32 @@ class SDE_Detector(Detector): ...@@ -505,13 +558,32 @@ class SDE_Detector(Detector):
online_scores, online_scores,
frame_id=frame_id, frame_id=frame_id,
fps=fps, fps=fps,
ids2names=[]) ids2names=ids2names,
do_entrance_counting=self.do_entrance_counting,
entrance=entrance,
records=records,
center_traj=center_traj)
writer.write(im) writer.write(im)
if camera_id != -1: if camera_id != -1:
cv2.imshow('Mask Detection', im) cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
break break
if self.save_mot_txts:
result_filename = os.path.join(
self.output_dir, video_out_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results)
result_filename = os.path.join(
self.output_dir,
video_out_name.split('.')[-2] + '_flow_statistic.txt')
f = open(result_filename, 'w')
for line in records:
f.write(line)
print('Flow statistic save in {}'.format(result_filename))
f.close()
writer.release() writer.release()
def predict_mtmct(self, mtmct_dir, mtmct_cfg): def predict_mtmct(self, mtmct_dir, mtmct_cfg):
...@@ -623,18 +695,23 @@ def main(): ...@@ -623,18 +695,23 @@ def main():
arch = yml_conf['arch'] arch = yml_conf['arch']
detector = SDE_Detector( detector = SDE_Detector(
FLAGS.model_dir, FLAGS.model_dir,
FLAGS.tracker_config, tracker_config=FLAGS.tracker_config,
device=FLAGS.device, device=FLAGS.device,
run_mode=FLAGS.run_mode, run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size, batch_size=1,
trt_min_shape=FLAGS.trt_min_shape, trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape, trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape, trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode, trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads, cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn, enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir, output_dir=FLAGS.output_dir,
threshold=FLAGS.threshold,
save_images=FLAGS.save_images,
save_mot_txts=FLAGS.save_mot_txts,
draw_center_traj=FLAGS.draw_center_traj,
secs_interval=FLAGS.secs_interval,
do_entrance_counting=FLAGS.do_entrance_counting,
reid_model_dir=FLAGS.reid_model_dir, reid_model_dir=FLAGS.reid_model_dir,
mtmct_dir=FLAGS.mtmct_dir, ) mtmct_dir=FLAGS.mtmct_dir, )
......
...@@ -137,6 +137,21 @@ def argsparser(): ...@@ -137,6 +137,21 @@ def argsparser():
type=ast.literal_eval, type=ast.literal_eval,
default=True, default=True,
help='whether to use darkpose to get better keypoint position predict ') help='whether to use darkpose to get better keypoint position predict ')
parser.add_argument(
"--do_entrance_counting",
action='store_true',
help="Whether counting the numbers of identifiers entering "
"or getting out from the entrance. Note that only support one-class"
"counting, multi-class counting is coming soon.")
parser.add_argument(
"--secs_interval",
type=int,
default=2,
help="The seconds interval to count after tracking")
parser.add_argument(
"--draw_center_traj",
action='store_true',
help="Whether drawing the trajectory of center")
parser.add_argument( parser.add_argument(
"--mtmct_dir", "--mtmct_dir",
type=str, type=str,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册