Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
bea79e74
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bea79e74
编写于
12月 08, 2021
作者:
F
Feng Ni
提交者:
GitHub
12月 08, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MOT] add mot api for pptracking (#4824)
* add mot jde api for pptracking * add pptracking sde api infer
上级
956726e5
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
339 addition
and
85 deletion
+339
-85
deploy/pptracking/python/README.md
deploy/pptracking/python/README.md
+79
-3
deploy/pptracking/python/mot_jde_infer.py
deploy/pptracking/python/mot_jde_infer.py
+94
-36
deploy/pptracking/python/mot_sde_infer.py
deploy/pptracking/python/mot_sde_infer.py
+166
-46
未找到文件。
deploy/pptracking/python/README.md
浏览文件 @
bea79e74
...
@@ -111,6 +111,7 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=ppyolov2_r50vd_dcn_
...
@@ -111,6 +111,7 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=ppyolov2_r50vd_dcn_
-
DeepSORT算法不支持多类别跟踪,只支持单类别跟踪,且ReID模型最好是与检测模型同一类别的物体训练过的,比如行人跟踪最好使用行人ReID模型,车辆跟踪最好使用车辆ReID模型。
-
DeepSORT算法不支持多类别跟踪,只支持单类别跟踪,且ReID模型最好是与检测模型同一类别的物体训练过的,比如行人跟踪最好使用行人ReID模型,车辆跟踪最好使用车辆ReID模型。
## 3. 跨境跟踪模型的导出和预测
## 3. 跨境跟踪模型的导出和预测
### 3.1 导出预测模型
### 3.1 导出预测模型
Step 1:下载导出的检测模型
Step 1:下载导出的检测模型
...
@@ -129,11 +130,15 @@ tar -xvf deepsort_pplcnet_vehicle.tar
...
@@ -129,11 +130,15 @@ tar -xvf deepsort_pplcnet_vehicle.tar
### 3.2 用导出的模型基于Python去做跨镜头跟踪
### 3.2 用导出的模型基于Python去做跨镜头跟踪
```
bash
```
bash
# 下载demo测试视频
wget https://paddledet.bj.bcebos.com/data/mot/demo/mtmct-demo.tar
tar
-xvf
mtmct-demo.tar
# 用导出的PicoDet车辆检测模型和PPLCNet车辆ReID模型
# 用导出的PicoDet车辆检测模型和PPLCNet车辆ReID模型
python deploy/pptracking/python/mot_sde_infer.py
--model_dir
=
picodet_l_640_aic21mtmct_vehicle/
--reid_model_dir
=
deepsort_pplcnet_vehicle/
--mtmct_dir
=
{
your mtmct scene video folder
}
--mtmct_cfg
=
mtmct_cfg
--device
=
GPU
--scaled
=
True
--threshold
=
0.5
--save_mot_txts
--save_images
python deploy/pptracking/python/mot_sde_infer.py
--model_dir
=
picodet_l_640_aic21mtmct_vehicle/
--reid_model_dir
=
deepsort_pplcnet_vehicle/
--mtmct_dir
=
mtmct-demo
--mtmct_cfg
=
mtmct_cfg
--device
=
GPU
--scaled
=
True
--threshold
=
0.5
--save_mot_txts
--save_images
# 用导出的PP-YOLOv2车辆检测模型和PPLCNet车辆ReID模型
# 用导出的PP-YOLOv2车辆检测模型和PPLCNet车辆ReID模型
python deploy/pptracking/python/mot_sde_infer.py
--model_dir
=
ppyolov2_r50vd_dcn_365e_aic21mtmct_vehicle/
--reid_model_dir
=
deepsort_pplcnet_vehicle/
--mtmct_dir
=
{
your mtmct scene video folder
}
--mtmct_cfg
=
mtmct_cfg
--device
=
GPU
--scaled
=
True
--threshold
=
0.5
--save_mot_txts
--save_images
python deploy/pptracking/python/mot_sde_infer.py
--model_dir
=
ppyolov2_r50vd_dcn_365e_aic21mtmct_vehicle/
--reid_model_dir
=
deepsort_pplcnet_vehicle/
--mtmct_dir
=
mtmct-demo
--mtmct_cfg
=
mtmct_cfg
--device
=
GPU
--scaled
=
True
--threshold
=
0.5
--save_mot_txts
--save_images
```
```
**注意:**
**注意:**
...
@@ -146,7 +151,78 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=ppyolov2_r50vd_dcn_
...
@@ -146,7 +151,78 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=ppyolov2_r50vd_dcn_
-
`--mtmct_cfg`
是MTMCT预测的某个场景的配置文件,里面包含该一些trick操作的开关和该场景摄像头相关设置的文件路径,用户可以自行更改相关路径以及设置某些操作是否启用。
-
`--mtmct_cfg`
是MTMCT预测的某个场景的配置文件,里面包含该一些trick操作的开关和该场景摄像头相关设置的文件路径,用户可以自行更改相关路径以及设置某些操作是否启用。
## 参数说明:
## 4. API调用方式:
### 4.1 FairMOT模型API调用
```
import mot_jde_infer
# 1.model config and weights
model_dir = 'fairmot_hrnetv2_w18_dlafpn_30e_576x320/'
# 2.inference data
video_file = 'test.mp4'
image_dir = None
# 3.other settings
device = 'CPU' # device should be CPU, GPU or XPU
threshold = 0.3
output_dir = 'output'
# mot predict
mot_jde_infer.predict_naive(model_dir, video_file, image_dir, device, threshold, output_dir)
```
**注意:**
-
以上代码必须进入目录
`PaddleDetection/deploy/pptracking/python`
下执行。
-
支持对视频和图片文件夹进行预测,不支持单张图的预测,
`video_file`
或
`image_dir`
不能同时为None,推荐使用
`video_file`
,而
`image_dir`
需直接存放命名顺序规范的图片。
-
默认会保存跟踪结果可视化后的图片和视频,以及跟踪结果txt文件,默认不会进行轨迹可视化和流量统计。
### 4.2 DeepSORT模型API调用
```
import mot_sde_infer
# 1.model config and weights
model_dir = 'ppyolov2_r50vd_dcn_365e_aic21mtmct_vehicle/'
reid_model_dir = 'deepsort_pplcnet_vehicle/'
# 2.inference data
video_file = 'test.mp4'
image_dir = None
# 3.other settings
scaled = True # set False only when use JDE YOLOv3
device = 'CPU' # device should be CPU, GPU or XPU
threshold = 0.3
output_dir = 'output'
# 4. MTMCT settings, default None
mtmct_dir = None
mtmct_cfg = None
# mot predict
mot_sde_infer.predict_naive(model_dir,
reid_model_dir,
video_file,
image_dir,
mtmct_dir,
mtmct_cfg,
scaled,
device,
threshold,
output_dir)
```
**注意:**
-
以上代码必须进入目录
`PaddleDetection/deploy/pptracking/python`
下执行。
-
支持对视频和图片文件夹进行预测,不支持单张图的预测,
`video_file`
或
`image_dir`
或
`--mtmct_dir`
不能同时为None,推荐使用
`video_file`
,而
`image_dir`
需直接存放命名顺序规范的图片,
`--mtmct_dir`
不为None表示是进行的MTMCT跨镜头跟踪任务。
-
默认会保存跟踪结果可视化后的图片和视频,以及跟踪结果txt文件,默认不会进行轨迹可视化和流量统计。
-
`--scaled`
表示在模型输出结果的坐标是否已经是缩放回原图的,如果使用的检测模型是JDE的YOLOv3则为False,如果使用通用检测模型则为True。
-
`--mtmct_dir`
是MTMCT预测的某个场景的文件夹名字,里面包含该场景不同摄像头拍摄视频的图片文件夹,其数量至少为两个。
-
`--mtmct_cfg`
是MTMCT预测的某个场景的配置文件,里面包含该一些trick操作的开关和该场景摄像头相关设置的文件路径,用户可以自行更改相关路径以及设置某些操作是否启用。
-
开启MTMCT预测必须将
`video_file`
和
`image_dir`
同时设置为None,且
`--mtmct_dir`
和
`--mtmct_cfg`
都必须不为None。
## 5. 参数说明:
| 参数 | 是否必须|含义 |
| 参数 | 是否必须|含义 |
|-------|-------|----------|
|-------|-------|----------|
...
...
deploy/pptracking/python/mot_jde_infer.py
浏览文件 @
bea79e74
...
@@ -167,7 +167,12 @@ class JDE_Detector(Detector):
...
@@ -167,7 +167,12 @@ class JDE_Detector(Detector):
return
online_tlwhs
,
online_scores
,
online_ids
return
online_tlwhs
,
online_scores
,
online_ids
def
predict_image
(
detector
,
image_list
):
def
predict_image
(
detector
,
image_list
,
threshold
,
output_dir
,
save_images
=
True
,
run_benchmark
=
False
):
results
=
[]
results
=
[]
num_classes
=
detector
.
num_classes
num_classes
=
detector
.
num_classes
data_type
=
'mcmot'
if
num_classes
>
1
else
'mot'
data_type
=
'mcmot'
if
num_classes
>
1
else
'mot'
...
@@ -176,13 +181,11 @@ def predict_image(detector, image_list):
...
@@ -176,13 +181,11 @@ def predict_image(detector, image_list):
image_list
.
sort
()
image_list
.
sort
()
for
frame_id
,
img_file
in
enumerate
(
image_list
):
for
frame_id
,
img_file
in
enumerate
(
image_list
):
frame
=
cv2
.
imread
(
img_file
)
frame
=
cv2
.
imread
(
img_file
)
if
FLAGS
.
run_benchmark
:
if
run_benchmark
:
# warmup
# warmup
detector
.
predict
(
detector
.
predict
([
img_file
],
threshold
,
repeats
=
10
,
add_timer
=
False
)
[
img_file
],
FLAGS
.
threshold
,
repeats
=
10
,
add_timer
=
False
)
# run benchmark
# run benchmark
detector
.
predict
(
detector
.
predict
([
img_file
],
threshold
,
repeats
=
10
,
add_timer
=
True
)
[
img_file
],
FLAGS
.
threshold
,
repeats
=
10
,
add_timer
=
True
)
cm
,
gm
,
gu
=
get_current_memory_mb
()
cm
,
gm
,
gu
=
get_current_memory_mb
()
detector
.
cpu_mem
+=
cm
detector
.
cpu_mem
+=
cm
detector
.
gpu_mem
+=
gm
detector
.
gpu_mem
+=
gm
...
@@ -190,7 +193,7 @@ def predict_image(detector, image_list):
...
@@ -190,7 +193,7 @@ def predict_image(detector, image_list):
print
(
'Test iter {}, file name:{}'
.
format
(
frame_id
,
img_file
))
print
(
'Test iter {}, file name:{}'
.
format
(
frame_id
,
img_file
))
else
:
else
:
online_tlwhs
,
online_scores
,
online_ids
=
detector
.
predict
(
online_tlwhs
,
online_scores
,
online_ids
=
detector
.
predict
(
[
img_file
],
FLAGS
.
threshold
)
[
img_file
],
threshold
)
online_im
=
plot_tracking_dict
(
online_im
=
plot_tracking_dict
(
frame
,
frame
,
num_classes
,
num_classes
,
...
@@ -199,22 +202,32 @@ def predict_image(detector, image_list):
...
@@ -199,22 +202,32 @@ def predict_image(detector, image_list):
online_scores
,
online_scores
,
frame_id
=
frame_id
,
frame_id
=
frame_id
,
ids2names
=
ids2names
)
ids2names
=
ids2names
)
if
FLAGS
.
save_images
:
if
save_images
:
if
not
os
.
path
.
exists
(
FLAGS
.
output_dir
):
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
FLAGS
.
output_dir
)
os
.
makedirs
(
output_dir
)
img_name
=
os
.
path
.
split
(
img_file
)[
-
1
]
img_name
=
os
.
path
.
split
(
img_file
)[
-
1
]
out_path
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
img_name
)
out_path
=
os
.
path
.
join
(
output_dir
,
img_name
)
cv2
.
imwrite
(
out_path
,
online_im
)
cv2
.
imwrite
(
out_path
,
online_im
)
print
(
"save result to: "
+
out_path
)
print
(
"save result to: "
+
out_path
)
def
predict_video
(
detector
,
camera_id
):
def
predict_video
(
detector
,
video_file
,
threshold
,
output_dir
,
save_images
=
True
,
save_mot_txts
=
True
,
draw_center_traj
=
False
,
secs_interval
=
10
,
do_entrance_counting
=
False
,
camera_id
=-
1
):
video_name
=
'mot_output.mp4'
video_name
=
'mot_output.mp4'
if
camera_id
!=
-
1
:
if
camera_id
!=
-
1
:
capture
=
cv2
.
VideoCapture
(
camera_id
)
capture
=
cv2
.
VideoCapture
(
camera_id
)
else
:
else
:
capture
=
cv2
.
VideoCapture
(
FLAGS
.
video_file
)
capture
=
cv2
.
VideoCapture
(
video_file
)
video_name
=
os
.
path
.
split
(
FLAGS
.
video_file
)[
-
1
]
video_name
=
os
.
path
.
split
(
video_file
)[
-
1
]
# Get Video info : resolution, fps, frame count
# Get Video info : resolution, fps, frame count
width
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
))
width
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
))
height
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
))
height
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
))
...
@@ -222,10 +235,10 @@ def predict_video(detector, camera_id):
...
@@ -222,10 +235,10 @@ def predict_video(detector, camera_id):
frame_count
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_COUNT
))
frame_count
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_COUNT
))
print
(
"fps: %d, frame_count: %d"
%
(
fps
,
frame_count
))
print
(
"fps: %d, frame_count: %d"
%
(
fps
,
frame_count
))
if
not
os
.
path
.
exists
(
FLAGS
.
output_dir
):
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
FLAGS
.
output_dir
)
os
.
makedirs
(
output_dir
)
out_path
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
video_name
)
out_path
=
os
.
path
.
join
(
output_dir
,
video_name
)
if
not
FLAGS
.
save_images
:
if
not
save_images
:
video_format
=
'mp4v'
video_format
=
'mp4v'
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
video_format
)
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
video_format
)
writer
=
cv2
.
VideoWriter
(
out_path
,
fourcc
,
fps
,
(
width
,
height
))
writer
=
cv2
.
VideoWriter
(
out_path
,
fourcc
,
fps
,
(
width
,
height
))
...
@@ -238,7 +251,7 @@ def predict_video(detector, camera_id):
...
@@ -238,7 +251,7 @@ def predict_video(detector, camera_id):
center_traj
=
None
center_traj
=
None
entrance
=
None
entrance
=
None
records
=
None
records
=
None
if
FLAGS
.
draw_center_traj
:
if
draw_center_traj
:
center_traj
=
[{}
for
i
in
range
(
num_classes
)]
center_traj
=
[{}
for
i
in
range
(
num_classes
)]
if
num_classes
==
1
:
if
num_classes
==
1
:
...
@@ -257,8 +270,8 @@ def predict_video(detector, camera_id):
...
@@ -257,8 +270,8 @@ def predict_video(detector, camera_id):
if
not
ret
:
if
not
ret
:
break
break
timer
.
tic
()
timer
.
tic
()
online_tlwhs
,
online_scores
,
online_ids
=
detector
.
predict
(
online_tlwhs
,
online_scores
,
online_ids
=
detector
.
predict
(
[
frame
],
[
frame
],
FLAGS
.
threshold
)
threshold
)
timer
.
toc
()
timer
.
toc
()
for
cls_id
in
range
(
num_classes
):
for
cls_id
in
range
(
num_classes
):
...
@@ -271,9 +284,9 @@ def predict_video(detector, camera_id):
...
@@ -271,9 +284,9 @@ def predict_video(detector, camera_id):
result
=
(
frame_id
+
1
,
online_tlwhs
[
0
],
online_scores
[
0
],
result
=
(
frame_id
+
1
,
online_tlwhs
[
0
],
online_scores
[
0
],
online_ids
[
0
])
online_ids
[
0
])
statistic
=
flow_statistic
(
statistic
=
flow_statistic
(
result
,
FLAGS
.
secs_interval
,
FLAGS
.
do_entrance_counting
,
result
,
secs_interval
,
do_entrance_counting
,
video_fps
,
video_fps
,
entrance
,
id_set
,
interval_id_set
,
in
_id_list
,
entrance
,
id_set
,
interval_id_set
,
in_id_list
,
out
_id_list
,
out_id_list
,
prev_center
,
records
,
data_type
,
num_classes
)
prev_center
,
records
,
data_type
,
num_classes
)
id_set
=
statistic
[
'id_set'
]
id_set
=
statistic
[
'id_set'
]
interval_id_set
=
statistic
[
'interval_id_set'
]
interval_id_set
=
statistic
[
'interval_id_set'
]
in_id_list
=
statistic
[
'in_id_list'
]
in_id_list
=
statistic
[
'in_id_list'
]
...
@@ -281,7 +294,7 @@ def predict_video(detector, camera_id):
...
@@ -281,7 +294,7 @@ def predict_video(detector, camera_id):
prev_center
=
statistic
[
'prev_center'
]
prev_center
=
statistic
[
'prev_center'
]
records
=
statistic
[
'records'
]
records
=
statistic
[
'records'
]
elif
num_classes
>
1
and
FLAGS
.
do_entrance_counting
:
elif
num_classes
>
1
and
do_entrance_counting
:
raise
NotImplementedError
(
raise
NotImplementedError
(
'Multi-class flow counting is not implemented now!'
)
'Multi-class flow counting is not implemented now!'
)
im
=
plot_tracking_dict
(
im
=
plot_tracking_dict
(
...
@@ -293,13 +306,13 @@ def predict_video(detector, camera_id):
...
@@ -293,13 +306,13 @@ def predict_video(detector, camera_id):
frame_id
=
frame_id
,
frame_id
=
frame_id
,
fps
=
fps
,
fps
=
fps
,
ids2names
=
ids2names
,
ids2names
=
ids2names
,
do_entrance_counting
=
FLAGS
.
do_entrance_counting
,
do_entrance_counting
=
do_entrance_counting
,
entrance
=
entrance
,
entrance
=
entrance
,
records
=
records
,
records
=
records
,
center_traj
=
center_traj
)
center_traj
=
center_traj
)
if
FLAGS
.
save_images
:
if
save_images
:
save_dir
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
video_name
.
split
(
'.'
)[
-
2
])
save_dir
=
os
.
path
.
join
(
output_dir
,
video_name
.
split
(
'.'
)[
-
2
])
if
not
os
.
path
.
exists
(
save_dir
):
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
os
.
makedirs
(
save_dir
)
cv2
.
imwrite
(
cv2
.
imwrite
(
...
@@ -313,24 +326,23 @@ def predict_video(detector, camera_id):
...
@@ -313,24 +326,23 @@ def predict_video(detector, camera_id):
cv2
.
imshow
(
'Tracking Detection'
,
im
)
cv2
.
imshow
(
'Tracking Detection'
,
im
)
if
cv2
.
waitKey
(
1
)
&
0xFF
==
ord
(
'q'
):
if
cv2
.
waitKey
(
1
)
&
0xFF
==
ord
(
'q'
):
break
break
if
FLAGS
.
save_mot_txts
:
if
save_mot_txts
:
result_filename
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
result_filename
=
os
.
path
.
join
(
output_dir
,
video_name
.
split
(
'.'
)[
-
2
]
+
'.txt'
)
video_name
.
split
(
'.'
)[
-
2
]
+
'.txt'
)
write_mot_results
(
result_filename
,
results
,
data_type
,
num_classes
)
write_mot_results
(
result_filename
,
results
,
data_type
,
num_classes
)
if
num_classes
==
1
:
if
num_classes
==
1
:
result_filename
=
os
.
path
.
join
(
result_filename
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
output_dir
,
video_name
.
split
(
'.'
)[
-
2
]
+
'_flow_statistic.txt'
)
video_name
.
split
(
'.'
)[
-
2
]
+
'_flow_statistic.txt'
)
f
=
open
(
result_filename
,
'w'
)
f
=
open
(
result_filename
,
'w'
)
for
line
in
records
:
for
line
in
records
:
f
.
write
(
line
)
f
.
write
(
line
)
print
(
'Flow statistic save in {}'
.
format
(
result_filename
))
print
(
'Flow statistic save in {}'
.
format
(
result_filename
))
f
.
close
()
f
.
close
()
if
FLAGS
.
save_images
:
if
save_images
:
save_dir
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
video_name
.
split
(
'.'
)[
-
2
])
save_dir
=
os
.
path
.
join
(
output_dir
,
video_name
.
split
(
'.'
)[
-
2
])
cmd_str
=
'ffmpeg -f image2 -i {}/%05d.jpg {}'
.
format
(
save_dir
,
cmd_str
=
'ffmpeg -f image2 -i {}/%05d.jpg {}'
.
format
(
save_dir
,
out_path
)
out_path
)
os
.
system
(
cmd_str
)
os
.
system
(
cmd_str
)
...
@@ -339,6 +351,36 @@ def predict_video(detector, camera_id):
...
@@ -339,6 +351,36 @@ def predict_video(detector, camera_id):
writer
.
release
()
writer
.
release
()
def
predict_naive
(
model_dir
,
video_file
,
image_dir
,
device
=
'gpu'
,
threshold
=
0.5
,
output_dir
=
'output'
):
pred_config
=
PredictConfig
(
model_dir
)
detector
=
JDE_Detector
(
pred_config
,
model_dir
,
device
=
device
.
upper
())
if
video_file
is
not
None
:
predict_video
(
detector
,
video_file
,
threshold
=
threshold
,
output_dir
=
output_dir
,
save_images
=
True
,
save_mot_txts
=
True
,
draw_center_traj
=
False
,
secs_interval
=
10
,
do_entrance_counting
=
False
)
else
:
img_list
=
get_test_images
(
image_dir
,
infer_img
=
None
)
predict_image
(
detector
,
img_list
,
threshold
=
threshold
,
output_dir
=
output_dir
,
save_images
=
True
)
def
main
():
def
main
():
pred_config
=
PredictConfig
(
FLAGS
.
model_dir
)
pred_config
=
PredictConfig
(
FLAGS
.
model_dir
)
detector
=
JDE_Detector
(
detector
=
JDE_Detector
(
...
@@ -355,11 +397,27 @@ def main():
...
@@ -355,11 +397,27 @@ def main():
# predict from video file or camera video stream
# predict from video file or camera video stream
if
FLAGS
.
video_file
is
not
None
or
FLAGS
.
camera_id
!=
-
1
:
if
FLAGS
.
video_file
is
not
None
or
FLAGS
.
camera_id
!=
-
1
:
predict_video
(
detector
,
FLAGS
.
camera_id
)
predict_video
(
detector
,
FLAGS
.
video_file
,
threshold
=
FLAGS
.
threshold
,
output_dir
=
FLAGS
.
output_dir
,
save_images
=
FLAGS
.
save_images
,
save_mot_txts
=
FLAGS
.
save_mot_txts
,
draw_center_traj
=
FLAGS
.
draw_center_traj
,
secs_interval
=
FLAGS
.
secs_interval
,
do_entrance_counting
=
FLAGS
.
do_entrance_counting
,
camera_id
=
FLAGS
.
camera_id
)
else
:
else
:
# predict from image
# predict from image
img_list
=
get_test_images
(
FLAGS
.
image_dir
,
FLAGS
.
image_file
)
img_list
=
get_test_images
(
FLAGS
.
image_dir
,
FLAGS
.
image_file
)
predict_image
(
detector
,
img_list
)
predict_image
(
detector
,
img_list
,
threshold
=
FLAGS
.
threshold
,
output_dir
=
FLAGS
.
output_dir
,
save_images
=
FLAGS
.
save_images
,
run_benchmark
=
FLAGS
.
run_benchmark
)
if
not
FLAGS
.
run_benchmark
:
if
not
FLAGS
.
run_benchmark
:
detector
.
det_times
.
info
(
average
=
True
)
detector
.
det_times
.
info
(
average
=
True
)
else
:
else
:
...
...
deploy/pptracking/python/mot_sde_infer.py
浏览文件 @
bea79e74
...
@@ -316,6 +316,8 @@ class SDE_DetectorPicoDet(DetectorPicoDet):
...
@@ -316,6 +316,8 @@ class SDE_DetectorPicoDet(DetectorPicoDet):
self
.
det_times
.
preprocess_time_s
.
end
()
self
.
det_times
.
preprocess_time_s
.
end
()
self
.
det_times
.
inference_time_s
.
start
()
self
.
det_times
.
inference_time_s
.
start
()
np_score_list
,
np_boxes_list
=
[],
[]
# model prediction
# model prediction
for
i
in
range
(
repeats
):
for
i
in
range
(
repeats
):
self
.
predictor
.
run
()
self
.
predictor
.
run
()
...
@@ -549,26 +551,33 @@ class SDE_ReID(object):
...
@@ -549,26 +551,33 @@ class SDE_ReID(object):
return
tracking_outs
return
tracking_outs
def
predict_image
(
detector
,
reid_model
,
image_list
):
def
predict_image
(
detector
,
reid_model
,
image_list
,
threshold
,
output_dir
,
scaled
=
True
,
save_images
=
True
,
run_benchmark
=
False
):
image_list
.
sort
()
image_list
.
sort
()
for
i
,
img_file
in
enumerate
(
image_list
):
for
i
,
img_file
in
enumerate
(
image_list
):
frame
=
cv2
.
imread
(
img_file
)
frame
=
cv2
.
imread
(
img_file
)
ori_image_shape
=
list
(
frame
.
shape
[:
2
])
ori_image_shape
=
list
(
frame
.
shape
[:
2
])
if
FLAGS
.
run_benchmark
:
if
run_benchmark
:
# warmup
# warmup
pred_dets
,
pred_xyxys
=
detector
.
predict
(
pred_dets
,
pred_xyxys
=
detector
.
predict
(
[
img_file
],
[
img_file
],
ori_image_shape
,
ori_image_shape
,
FLAGS
.
threshold
,
threshold
,
FLAGS
.
scaled
,
scaled
,
repeats
=
10
,
repeats
=
10
,
add_timer
=
False
)
add_timer
=
False
)
# run benchmark
# run benchmark
pred_dets
,
pred_xyxys
=
detector
.
predict
(
pred_dets
,
pred_xyxys
=
detector
.
predict
(
[
img_file
],
[
img_file
],
ori_image_shape
,
ori_image_shape
,
FLAGS
.
threshold
,
threshold
,
FLAGS
.
scaled
,
scaled
,
repeats
=
10
,
repeats
=
10
,
add_timer
=
True
)
add_timer
=
True
)
...
@@ -579,7 +588,7 @@ def predict_image(detector, reid_model, image_list):
...
@@ -579,7 +588,7 @@ def predict_image(detector, reid_model, image_list):
print
(
'Test iter {}, file name:{}'
.
format
(
i
,
img_file
))
print
(
'Test iter {}, file name:{}'
.
format
(
i
,
img_file
))
else
:
else
:
pred_dets
,
pred_xyxys
=
detector
.
predict
(
pred_dets
,
pred_xyxys
=
detector
.
predict
(
[
img_file
],
ori_image_shape
,
FLAGS
.
threshold
,
FLAGS
.
scaled
)
[
img_file
],
ori_image_shape
,
threshold
,
scaled
)
if
len
(
pred_dets
)
==
1
and
np
.
sum
(
pred_dets
)
==
0
:
if
len
(
pred_dets
)
==
1
and
np
.
sum
(
pred_dets
)
==
0
:
print
(
'Frame {} has no object, try to modify score threshold.'
.
print
(
'Frame {} has no object, try to modify score threshold.'
.
...
@@ -589,7 +598,7 @@ def predict_image(detector, reid_model, image_list):
...
@@ -589,7 +598,7 @@ def predict_image(detector, reid_model, image_list):
# reid process
# reid process
crops
=
reid_model
.
get_crops
(
pred_xyxys
,
frame
)
crops
=
reid_model
.
get_crops
(
pred_xyxys
,
frame
)
if
FLAGS
.
run_benchmark
:
if
run_benchmark
:
# warmup
# warmup
tracking_outs
=
reid_model
.
predict
(
tracking_outs
=
reid_model
.
predict
(
crops
,
pred_dets
,
repeats
=
10
,
add_timer
=
False
)
crops
,
pred_dets
,
repeats
=
10
,
add_timer
=
False
)
...
@@ -607,22 +616,34 @@ def predict_image(detector, reid_model, image_list):
...
@@ -607,22 +616,34 @@ def predict_image(detector, reid_model, image_list):
online_im
=
plot_tracking
(
online_im
=
plot_tracking
(
frame
,
online_tlwhs
,
online_ids
,
online_scores
,
frame_id
=
i
)
frame
,
online_tlwhs
,
online_ids
,
online_scores
,
frame_id
=
i
)
if
FLAGS
.
save_images
:
if
save_images
:
if
not
os
.
path
.
exists
(
FLAGS
.
output_dir
):
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
FLAGS
.
output_dir
)
os
.
makedirs
(
output_dir
)
img_name
=
os
.
path
.
split
(
img_file
)[
-
1
]
img_name
=
os
.
path
.
split
(
img_file
)[
-
1
]
out_path
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
img_name
)
out_path
=
os
.
path
.
join
(
output_dir
,
img_name
)
cv2
.
imwrite
(
out_path
,
online_im
)
cv2
.
imwrite
(
out_path
,
online_im
)
print
(
"save result to: "
+
out_path
)
print
(
"save result to: "
+
out_path
)
def
predict_video
(
detector
,
reid_model
,
camera_id
):
def
predict_video
(
detector
,
reid_model
,
video_file
,
scaled
,
threshold
,
output_dir
,
save_images
=
True
,
save_mot_txts
=
True
,
draw_center_traj
=
False
,
secs_interval
=
10
,
do_entrance_counting
=
False
,
camera_id
=-
1
):
video_name
=
'mot_output.mp4'
if
camera_id
!=
-
1
:
if
camera_id
!=
-
1
:
capture
=
cv2
.
VideoCapture
(
camera_id
)
capture
=
cv2
.
VideoCapture
(
camera_id
)
video_name
=
'mot_output.mp4'
else
:
else
:
capture
=
cv2
.
VideoCapture
(
FLAGS
.
video_file
)
capture
=
cv2
.
VideoCapture
(
video_file
)
video_name
=
os
.
path
.
split
(
FLAGS
.
video_file
)[
-
1
]
video_name
=
os
.
path
.
split
(
video_file
)[
-
1
]
# Get Video info : resolution, fps, frame count
# Get Video info : resolution, fps, frame count
width
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
))
width
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
))
height
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
))
height
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
))
...
@@ -630,10 +651,10 @@ def predict_video(detector, reid_model, camera_id):
...
@@ -630,10 +651,10 @@ def predict_video(detector, reid_model, camera_id):
frame_count
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_COUNT
))
frame_count
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_COUNT
))
print
(
"fps: %d, frame_count: %d"
%
(
fps
,
frame_count
))
print
(
"fps: %d, frame_count: %d"
%
(
fps
,
frame_count
))
if
not
os
.
path
.
exists
(
FLAGS
.
output_dir
):
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
FLAGS
.
output_dir
)
os
.
makedirs
(
output_dir
)
out_path
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
video_name
)
out_path
=
os
.
path
.
join
(
output_dir
,
video_name
)
if
not
FLAGS
.
save_images
:
if
not
save_images
:
video_format
=
'mp4v'
video_format
=
'mp4v'
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
video_format
)
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
video_format
)
writer
=
cv2
.
VideoWriter
(
out_path
,
fourcc
,
fps
,
(
width
,
height
))
writer
=
cv2
.
VideoWriter
(
out_path
,
fourcc
,
fps
,
(
width
,
height
))
...
@@ -656,7 +677,7 @@ def predict_video(detector, reid_model, camera_id):
...
@@ -656,7 +677,7 @@ def predict_video(detector, reid_model, camera_id):
timer
.
tic
()
timer
.
tic
()
ori_image_shape
=
list
(
frame
.
shape
[:
2
])
ori_image_shape
=
list
(
frame
.
shape
[:
2
])
pred_dets
,
pred_xyxys
=
detector
.
predict
([
frame
],
ori_image_shape
,
pred_dets
,
pred_xyxys
=
detector
.
predict
([
frame
],
ori_image_shape
,
FLAGS
.
threshold
,
FLAGS
.
scaled
)
threshold
,
scaled
)
if
len
(
pred_dets
)
==
1
and
np
.
sum
(
pred_dets
)
==
0
:
if
len
(
pred_dets
)
==
1
and
np
.
sum
(
pred_dets
)
==
0
:
print
(
'Frame {} has no object, try to modify score threshold.'
.
print
(
'Frame {} has no object, try to modify score threshold.'
.
...
@@ -677,9 +698,9 @@ def predict_video(detector, reid_model, camera_id):
...
@@ -677,9 +698,9 @@ def predict_video(detector, reid_model, camera_id):
# NOTE: just implement flow statistic for one class
# NOTE: just implement flow statistic for one class
result
=
(
frame_id
+
1
,
online_tlwhs
,
online_scores
,
online_ids
)
result
=
(
frame_id
+
1
,
online_tlwhs
,
online_scores
,
online_ids
)
statistic
=
flow_statistic
(
statistic
=
flow_statistic
(
result
,
FLAGS
.
secs_interval
,
FLAGS
.
do_entrance_counting
,
result
,
secs_interval
,
do_entrance_counting
,
video_fps
,
video_fps
,
entrance
,
id_set
,
interval_id_set
,
in
_id_list
,
entrance
,
id_set
,
interval_id_set
,
in_id_list
,
out
_id_list
,
out_id_list
,
prev_center
,
records
)
prev_center
,
records
)
id_set
=
statistic
[
'id_set'
]
id_set
=
statistic
[
'id_set'
]
interval_id_set
=
statistic
[
'interval_id_set'
]
interval_id_set
=
statistic
[
'interval_id_set'
]
in_id_list
=
statistic
[
'in_id_list'
]
in_id_list
=
statistic
[
'in_id_list'
]
...
@@ -697,11 +718,11 @@ def predict_video(detector, reid_model, camera_id):
...
@@ -697,11 +718,11 @@ def predict_video(detector, reid_model, camera_id):
online_scores
,
online_scores
,
frame_id
=
frame_id
,
frame_id
=
frame_id
,
fps
=
fps
,
fps
=
fps
,
do_entrance_counting
=
FLAGS
.
do_entrance_counting
,
do_entrance_counting
=
do_entrance_counting
,
entrance
=
entrance
)
entrance
=
entrance
)
if
FLAGS
.
save_images
:
if
save_images
:
save_dir
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
video_name
.
split
(
'.'
)[
-
2
])
save_dir
=
os
.
path
.
join
(
output_dir
,
video_name
.
split
(
'.'
)[
-
2
])
if
not
os
.
path
.
exists
(
save_dir
):
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
os
.
makedirs
(
save_dir
)
cv2
.
imwrite
(
cv2
.
imwrite
(
...
@@ -717,21 +738,21 @@ def predict_video(detector, reid_model, camera_id):
...
@@ -717,21 +738,21 @@ def predict_video(detector, reid_model, camera_id):
if
cv2
.
waitKey
(
1
)
&
0xFF
==
ord
(
'q'
):
if
cv2
.
waitKey
(
1
)
&
0xFF
==
ord
(
'q'
):
break
break
if
FLAGS
.
save_mot_txts
:
if
save_mot_txts
:
result_filename
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
result_filename
=
os
.
path
.
join
(
output_dir
,
video_name
.
split
(
'.'
)[
-
2
]
+
'.txt'
)
video_name
.
split
(
'.'
)[
-
2
]
+
'.txt'
)
write_mot_results
(
result_filename
,
results
)
write_mot_results
(
result_filename
,
results
)
result_filename
=
os
.
path
.
join
(
result_filename
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
video_name
.
split
(
'.'
)[
-
2
]
+
'_flow_statistic.txt'
)
output_dir
,
video_name
.
split
(
'.'
)[
-
2
]
+
'_flow_statistic.txt'
)
f
=
open
(
result_filename
,
'w'
)
f
=
open
(
result_filename
,
'w'
)
for
line
in
records
:
for
line
in
records
:
f
.
write
(
line
)
f
.
write
(
line
)
print
(
'Flow statistic save in {}'
.
format
(
result_filename
))
print
(
'Flow statistic save in {}'
.
format
(
result_filename
))
f
.
close
()
f
.
close
()
if
FLAGS
.
save_images
:
if
save_images
:
save_dir
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
video_name
.
split
(
'.'
)[
-
2
])
save_dir
=
os
.
path
.
join
(
output_dir
,
video_name
.
split
(
'.'
)[
-
2
])
cmd_str
=
'ffmpeg -f image2 -i {}/%05d.jpg {}'
.
format
(
save_dir
,
cmd_str
=
'ffmpeg -f image2 -i {}/%05d.jpg {}'
.
format
(
save_dir
,
out_path
)
out_path
)
os
.
system
(
cmd_str
)
os
.
system
(
cmd_str
)
...
@@ -740,8 +761,16 @@ def predict_video(detector, reid_model, camera_id):
...
@@ -740,8 +761,16 @@ def predict_video(detector, reid_model, camera_id):
writer
.
release
()
writer
.
release
()
def
predict_mtmct_seq
(
detector
,
reid_model
,
seq_name
,
output_dir
):
def
predict_mtmct_seq
(
detector
,
fpath
=
os
.
path
.
join
(
FLAGS
.
mtmct_dir
,
seq_name
)
reid_model
,
mtmct_dir
,
seq_name
,
scaled
,
threshold
,
output_dir
,
save_images
=
True
,
save_mot_txts
=
True
):
fpath
=
os
.
path
.
join
(
mtmct_dir
,
seq_name
)
if
os
.
path
.
exists
(
os
.
path
.
join
(
fpath
,
'img1'
)):
if
os
.
path
.
exists
(
os
.
path
.
join
(
fpath
,
'img1'
)):
fpath
=
os
.
path
.
join
(
fpath
,
'img1'
)
fpath
=
os
.
path
.
join
(
fpath
,
'img1'
)
...
@@ -756,13 +785,13 @@ def predict_mtmct_seq(detector, reid_model, seq_name, output_dir):
...
@@ -756,13 +785,13 @@ def predict_mtmct_seq(detector, reid_model, seq_name, output_dir):
len
(
image_list
),
seq_name
))
len
(
image_list
),
seq_name
))
for
frame_id
,
img_file
in
enumerate
(
image_list
):
for
frame_id
,
img_file
in
enumerate
(
image_list
):
if
frame_id
%
4
0
==
0
:
if
frame_id
%
1
0
==
0
:
print
(
'Processing frame {} of seq {}.'
.
format
(
frame_id
,
seq_name
))
print
(
'Processing frame {} of seq {}.'
.
format
(
frame_id
,
seq_name
))
frame
=
cv2
.
imread
(
os
.
path
.
join
(
fpath
,
img_file
))
frame
=
cv2
.
imread
(
os
.
path
.
join
(
fpath
,
img_file
))
ori_image_shape
=
list
(
frame
.
shape
[:
2
])
ori_image_shape
=
list
(
frame
.
shape
[:
2
])
frame_path
=
os
.
path
.
join
(
fpath
,
img_file
)
frame_path
=
os
.
path
.
join
(
fpath
,
img_file
)
pred_dets
,
pred_xyxys
=
detector
.
predict
([
frame_path
],
ori_image_shape
,
pred_dets
,
pred_xyxys
=
detector
.
predict
([
frame_path
],
ori_image_shape
,
FLAGS
.
threshold
,
FLAGS
.
scaled
)
threshold
,
scaled
)
if
len
(
pred_dets
)
==
1
and
np
.
sum
(
pred_dets
)
==
0
:
if
len
(
pred_dets
)
==
1
and
np
.
sum
(
pred_dets
)
==
0
:
print
(
'Frame {} has no object, try to modify score threshold.'
.
print
(
'Frame {} has no object, try to modify score threshold.'
.
...
@@ -791,21 +820,29 @@ def predict_mtmct_seq(detector, reid_model, seq_name, output_dir):
...
@@ -791,21 +820,29 @@ def predict_mtmct_seq(detector, reid_model, seq_name, output_dir):
results
[
0
].
append
(
results
[
0
].
append
(
(
frame_id
+
1
,
online_tlwhs
,
online_scores
,
online_ids
))
(
frame_id
+
1
,
online_tlwhs
,
online_scores
,
online_ids
))
if
FLAGS
.
save_images
:
if
save_images
:
save_dir
=
os
.
path
.
join
(
output_dir
,
seq_name
)
save_dir
=
os
.
path
.
join
(
output_dir
,
seq_name
)
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
img_name
=
os
.
path
.
split
(
img_file
)[
-
1
]
img_name
=
os
.
path
.
split
(
img_file
)[
-
1
]
out_path
=
os
.
path
.
join
(
save_dir
,
img_name
)
out_path
=
os
.
path
.
join
(
save_dir
,
img_name
)
cv2
.
imwrite
(
out_path
,
online_im
)
cv2
.
imwrite
(
out_path
,
online_im
)
if
FLAGS
.
save_mot_txts
:
if
save_mot_txts
:
result_filename
=
os
.
path
.
join
(
output_dir
,
seq_name
+
'.txt'
)
result_filename
=
os
.
path
.
join
(
output_dir
,
seq_name
+
'.txt'
)
write_mot_results
(
result_filename
,
results
)
write_mot_results
(
result_filename
,
results
)
return
mot_features_dict
return
mot_features_dict
def
predict_mtmct
(
detector
,
reid_model
,
mtmct_dir
,
mtmct_cfg
):
def
predict_mtmct
(
detector
,
reid_model
,
mtmct_dir
,
mtmct_cfg
,
scaled
,
threshold
,
output_dir
,
save_images
=
True
,
save_mot_txts
=
True
):
MTMCT
=
mtmct_cfg
[
'MTMCT'
]
MTMCT
=
mtmct_cfg
[
'MTMCT'
]
assert
MTMCT
==
True
,
'predict_mtmct should be used for MTMCT.'
assert
MTMCT
==
True
,
'predict_mtmct should be used for MTMCT.'
...
@@ -832,7 +869,6 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
...
@@ -832,7 +869,6 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
mot_list_breaks
=
[]
mot_list_breaks
=
[]
cid_tid_dict
=
dict
()
cid_tid_dict
=
dict
()
output_dir
=
FLAGS
.
output_dir
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
output_dir
)
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
output_dir
)
seqs
=
os
.
listdir
(
mtmct_dir
)
seqs
=
os
.
listdir
(
mtmct_dir
)
...
@@ -852,8 +888,9 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
...
@@ -852,8 +888,9 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
print
(
'{} is not a image folder.'
.
format
(
fpath
))
print
(
'{} is not a image folder.'
.
format
(
fpath
))
continue
continue
mot_features_dict
=
predict_mtmct_seq
(
detector
,
reid_model
,
seq
,
mot_features_dict
=
predict_mtmct_seq
(
output_dir
)
detector
,
reid_model
,
mtmct_dir
,
seq
,
scaled
,
threshold
,
output_dir
,
save_images
,
save_mot_txts
)
cid
=
int
(
re
.
sub
(
'[a-z,A-Z]'
,
""
,
seq
))
cid
=
int
(
re
.
sub
(
'[a-z,A-Z]'
,
""
,
seq
))
tid_data
,
mot_list_break
=
trajectory_fusion
(
tid_data
,
mot_list_break
=
trajectory_fusion
(
...
@@ -911,6 +948,62 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
...
@@ -911,6 +948,62 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
print_mtmct_result
(
data_root_gt
,
pred_mtmct_file
)
print_mtmct_result
(
data_root_gt
,
pred_mtmct_file
)
def
predict_naive
(
model_dir
,
reid_model_dir
,
video_file
,
image_dir
,
mtmct_dir
=
None
,
mtmct_cfg
=
None
,
scaled
=
True
,
device
=
'gpu'
,
threshold
=
0.5
,
output_dir
=
'output'
):
pred_config
=
PredictConfig
(
model_dir
)
detector_func
=
'SDE_Detector'
if
pred_config
.
arch
==
'PicoDet'
:
detector_func
=
'SDE_DetectorPicoDet'
detector
=
eval
(
detector_func
)(
pred_config
,
model_dir
,
device
=
device
)
pred_config
=
PredictConfig
(
reid_model_dir
)
reid_model
=
SDE_ReID
(
pred_config
,
reid_model_dir
,
device
=
device
)
if
video_file
is
not
None
:
predict_video
(
detector
,
reid_model
,
video_file
,
scaled
=
scaled
,
threshold
=
threshold
,
output_dir
=
output_dir
,
save_images
=
True
,
save_mot_txts
=
True
,
draw_center_traj
=
False
,
secs_interval
=
10
,
do_entrance_counting
=
False
)
elif
mtmct_dir
is
not
None
:
with
open
(
mtmct_cfg
)
as
f
:
mtmct_cfg_file
=
yaml
.
safe_load
(
f
)
predict_mtmct
(
detector
,
reid_model
,
mtmct_dir
,
mtmct_cfg_file
,
scaled
=
scaled
,
threshold
=
threshold
,
output_dir
=
output_dir
,
save_images
=
True
,
save_mot_txts
=
True
)
else
:
img_list
=
get_test_images
(
image_dir
,
infer_img
=
None
)
predict_image
(
detector
,
reid_model
,
img_list
,
threshold
=
threshold
,
output_dir
=
output_dir
,
save_images
=
True
)
def
main
():
def
main
():
pred_config
=
PredictConfig
(
FLAGS
.
model_dir
)
pred_config
=
PredictConfig
(
FLAGS
.
model_dir
)
detector_func
=
'SDE_Detector'
detector_func
=
'SDE_Detector'
...
@@ -945,18 +1038,45 @@ def main():
...
@@ -945,18 +1038,45 @@ def main():
# predict from video file or camera video stream
# predict from video file or camera video stream
if
FLAGS
.
video_file
is
not
None
or
FLAGS
.
camera_id
!=
-
1
:
if
FLAGS
.
video_file
is
not
None
or
FLAGS
.
camera_id
!=
-
1
:
predict_video
(
detector
,
reid_model
,
FLAGS
.
camera_id
)
predict_video
(
detector
,
reid_model
,
FLAGS
.
video_file
,
scaled
=
FLAGS
.
scaled
,
threshold
=
FLAGS
.
threshold
,
output_dir
=
FLAGS
.
output_dir
,
save_images
=
FLAGS
.
save_images
,
save_mot_txts
=
FLAGS
.
save_mot_txts
,
draw_center_traj
=
FLAGS
.
draw_center_traj
,
secs_interval
=
FLAGS
.
secs_interval
,
do_entrance_counting
=
FLAGS
.
do_entrance_counting
,
camera_id
=
FLAGS
.
camera_id
)
elif
FLAGS
.
mtmct_dir
is
not
None
:
elif
FLAGS
.
mtmct_dir
is
not
None
:
mtmct_cfg_file
=
FLAGS
.
mtmct_cfg
mtmct_cfg_file
=
FLAGS
.
mtmct_cfg
with
open
(
mtmct_cfg_file
)
as
f
:
with
open
(
mtmct_cfg_file
)
as
f
:
mtmct_cfg
=
yaml
.
safe_load
(
f
)
mtmct_cfg
=
yaml
.
safe_load
(
f
)
predict_mtmct
(
detector
,
reid_model
,
FLAGS
.
mtmct_dir
,
mtmct_cfg
)
predict_mtmct
(
detector
,
reid_model
,
FLAGS
.
mtmct_dir
,
mtmct_cfg
,
scaled
=
FLAGS
.
scaled
,
threshold
=
FLAGS
.
threshold
,
output_dir
=
FLAGS
.
output_dir
,
save_images
=
FLAGS
.
save_images
,
save_mot_txts
=
FLAGS
.
save_mot_txts
)
else
:
else
:
# predict from image
# predict from image
img_list
=
get_test_images
(
FLAGS
.
image_dir
,
FLAGS
.
image_file
)
img_list
=
get_test_images
(
FLAGS
.
image_dir
,
FLAGS
.
image_file
)
predict_image
(
detector
,
reid_model
,
img_list
)
predict_image
(
detector
,
reid_model
,
img_list
,
threshold
=
FLAGS
.
threshold
,
output_dir
=
FLAGS
.
output_dir
,
save_images
=
FLAGS
.
save_images
,
run_benchmark
=
FLAGS
.
run_benchmark
)
if
not
FLAGS
.
run_benchmark
:
if
not
FLAGS
.
run_benchmark
:
detector
.
det_times
.
info
(
average
=
True
)
detector
.
det_times
.
info
(
average
=
True
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录