Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
822c7933
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
822c7933
编写于
4月 08, 2022
作者:
F
Feng Ni
提交者:
GitHub
4月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cherry-pick][MOT] fix deploy python mot infer (#5630)
* fix deploy python mot infer, fix cfgs * fix doc, test=document_fix
上级
20b2f101
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
434 addition
and
157 deletion
+434
-157
configs/mot/jde/_base_/jde_darknet53.yml
configs/mot/jde/_base_/jde_darknet53.yml
+1
-1
deploy/pphuman/config/tracker_config.yml
deploy/pphuman/config/tracker_config.yml
+2
-2
deploy/pptracking/python/mot/tracker/jde_tracker.py
deploy/pptracking/python/mot/tracker/jde_tracker.py
+10
-13
deploy/pptracking/python/mot_jde_infer.py
deploy/pptracking/python/mot_jde_infer.py
+3
-3
deploy/pptracking/python/mot_sde_infer.py
deploy/pptracking/python/mot_sde_infer.py
+4
-4
deploy/pptracking/python/tracker_config.yml
deploy/pptracking/python/tracker_config.yml
+2
-2
deploy/python/README.md
deploy/python/README.md
+54
-5
deploy/python/mot_jde_infer.py
deploy/python/mot_jde_infer.py
+52
-25
deploy/python/mot_keypoint_unite_infer.py
deploy/python/mot_keypoint_unite_infer.py
+5
-5
deploy/python/mot_sde_infer.py
deploy/python/mot_sde_infer.py
+269
-78
deploy/python/tracker_config.yml
deploy/python/tracker_config.yml
+22
-6
ppdet/modeling/mot/tracker/jde_tracker.py
ppdet/modeling/mot/tracker/jde_tracker.py
+10
-13
未找到文件。
configs/mot/jde/_base_/jde_darknet53.yml
浏览文件 @
822c7933
...
...
@@ -53,4 +53,4 @@ JDETracker:
det_thresh
:
0.3
track_buffer
:
30
min_box_area
:
200
motion
:
KalmanFilter
vertical_ratio
:
1.6
# for pedestrian
deploy/pphuman/config/tracker_config.yml
浏览文件 @
822c7933
...
...
@@ -11,8 +11,8 @@ JDETracker:
conf_thres
:
0.6
low_conf_thres
:
0.1
match_thres
:
0.9
min_box_area
:
10
0
vertical_ratio
:
1.6
#
for pedestrian
min_box_area
:
0
vertical_ratio
:
0
# 1.6
for pedestrian
DeepSORTTracker
:
input_size
:
[
64
,
192
]
...
...
deploy/pptracking/python/mot/tracker/jde_tracker.py
浏览文件 @
822c7933
...
...
@@ -38,7 +38,7 @@ class JDETracker(object):
track_buffer (int): buffer for tracker
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
bad results. If set <0 means no need to filter bboxes,usually set
bad results. If set <
=
0 means no need to filter bboxes,usually set
1.6 for pedestrian tracking.
tracked_thresh (float): linear assignment threshold of tracked
stracks and detections
...
...
@@ -64,8 +64,8 @@ class JDETracker(object):
num_classes
=
1
,
det_thresh
=
0.3
,
track_buffer
=
30
,
min_box_area
=
20
0
,
vertical_ratio
=
1.6
,
min_box_area
=
0
,
vertical_ratio
=
0
,
tracked_thresh
=
0.7
,
r_tracked_thresh
=
0.5
,
unconfirmed_thresh
=
0.7
,
...
...
@@ -161,9 +161,8 @@ class JDETracker(object):
detections
=
[
STrack
(
STrack
.
tlbr_to_tlwh
(
tlbrs
[
2
:
6
]),
tlbrs
[
1
],
cls_id
,
30
,
temp_feat
)
for
(
tlbrs
,
temp_feat
)
in
zip
(
pred_dets_cls
,
pred_embs_cls
)
30
,
temp_feat
)
for
(
tlbrs
,
temp_feat
)
in
zip
(
pred_dets_cls
,
pred_embs_cls
)
]
else
:
detections
=
[]
...
...
@@ -238,15 +237,13 @@ class JDETracker(object):
for
tlbrs
in
pred_dets_cls_second
]
else
:
pred_embs_cls_second
=
pred_embs_dict
[
cls_id
][
inds_second
]
pred_embs_cls_second
=
pred_embs_dict
[
cls_id
][
inds_second
]
detections_second
=
[
STrack
(
STrack
.
tlbr_to_tlwh
(
tlbrs
[
2
:
6
]),
tlbrs
[
1
],
cls_id
,
30
,
temp_feat
)
for
(
tlbrs
,
temp_feat
)
in
zip
(
pred_dets_cls_second
,
pred_embs_cls_second
)
STrack
.
tlbr_to_tlwh
(
tlbrs
[
2
:
6
]),
tlbrs
[
1
],
cls_id
,
30
,
temp_feat
)
for
(
tlbrs
,
temp_feat
)
in
zip
(
pred_dets_cls_second
,
pred_embs_cls_second
)
]
else
:
detections_second
=
[]
...
...
deploy/pptracking/python/mot_jde_infer.py
浏览文件 @
822c7933
...
...
@@ -112,8 +112,8 @@ class JDE_Detector(Detector):
# tracker config
assert
self
.
pred_config
.
tracker
,
"The exported JDE Detector model should have tracker."
cfg
=
self
.
pred_config
.
tracker
min_box_area
=
cfg
.
get
(
'min_box_area'
,
20
0
)
vertical_ratio
=
cfg
.
get
(
'vertical_ratio'
,
1.6
)
min_box_area
=
cfg
.
get
(
'min_box_area'
,
0.
0
)
vertical_ratio
=
cfg
.
get
(
'vertical_ratio'
,
0.0
)
conf_thres
=
cfg
.
get
(
'conf_thres'
,
0.0
)
tracked_thresh
=
cfg
.
get
(
'tracked_thresh'
,
0.7
)
metric_type
=
cfg
.
get
(
'metric_type'
,
'euclidean'
)
...
...
@@ -164,7 +164,7 @@ class JDE_Detector(Detector):
repeats (int): repeats number for prediction
Returns:
result (dict): include 'pred_dets': np.ndarray: shape:[N,6], N: number of box,
matix element:[
x_min, y_min, x_max, y_max, score, class
]
matix element:[
class, score, x_min, y_min, x_max, y_max
]
FairMOT(JDE)'s result include 'pred_embs': np.ndarray:
shape: [N, 128]
'''
...
...
deploy/pptracking/python/mot_sde_infer.py
浏览文件 @
822c7933
...
...
@@ -165,8 +165,8 @@ class SDE_Detector(Detector):
# use ByteTracker
use_byte
=
cfg
.
get
(
'use_byte'
,
False
)
det_thresh
=
cfg
.
get
(
'det_thresh'
,
0.3
)
min_box_area
=
cfg
.
get
(
'min_box_area'
,
20
0
)
vertical_ratio
=
cfg
.
get
(
'vertical_ratio'
,
1.6
)
min_box_area
=
cfg
.
get
(
'min_box_area'
,
0
)
vertical_ratio
=
cfg
.
get
(
'vertical_ratio'
,
0
)
match_thres
=
cfg
.
get
(
'match_thres'
,
0.9
)
conf_thres
=
cfg
.
get
(
'conf_thres'
,
0.6
)
low_conf_thres
=
cfg
.
get
(
'low_conf_thres'
,
0.1
)
...
...
@@ -194,7 +194,7 @@ class SDE_Detector(Detector):
return
result
def
reidprocess
(
self
,
det_results
,
repeats
=
1
):
pred_dets
=
det_results
[
'boxes'
]
pred_dets
=
det_results
[
'boxes'
]
# cls_id, score, x0, y0, x1, y1
pred_xyxys
=
pred_dets
[:,
2
:
6
]
ori_image
=
det_results
[
'ori_image'
]
...
...
@@ -234,7 +234,7 @@ class SDE_Detector(Detector):
return
det_results
def
tracking
(
self
,
det_results
):
pred_dets
=
det_results
[
'boxes'
]
pred_dets
=
det_results
[
'boxes'
]
# cls_id, score, x0, y0, x1, y1
pred_embs
=
det_results
.
get
(
'embeddings'
,
None
)
if
self
.
use_deepsort_tracker
:
...
...
deploy/pptracking/python/tracker_config.yml
浏览文件 @
822c7933
...
...
@@ -11,8 +11,8 @@ JDETracker:
conf_thres
:
0.6
low_conf_thres
:
0.1
match_thres
:
0.9
min_box_area
:
10
0
vertical_ratio
:
1.6
#
for pedestrian
min_box_area
:
0
vertical_ratio
:
0
# 1.6
for pedestrian
DeepSORTTracker
:
input_size
:
[
64
,
192
]
...
...
deploy/python/README.md
浏览文件 @
822c7933
...
...
@@ -3,27 +3,76 @@
在PaddlePaddle中预测引擎和训练引擎底层有着不同的优化方法, 预测引擎使用了AnalysisPredictor,专门针对推理进行了优化,是基于
[
C++预测库
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/native_infer.html
)
的Python接口,该引擎可以对模型进行多项图优化,减少不必要的内存拷贝。如果用户在部署已训练模型的过程中对性能有较高的要求,我们提供了独立于PaddleDetection的预测脚本,方便用户直接集成部署。
主要包含两个步骤:
Python端预测部署主要包含两个步骤:
-
导出预测模型
-
基于Python进行预测
## 1. 导出预测模型
PaddleDetection在训练过程包括网络的前向和优化器相关参数,而在部署过程中,我们只需要前向参数,具体参考:
[
导出模型
](
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/deploy/EXPORT_MODEL.md
)
PaddleDetection在训练过程包括网络的前向和优化器相关参数,而在部署过程中,我们只需要前向参数,具体参考:
[
导出模型
](
../deploy/EXPORT_MODEL.md
)
,例如
```
bash
# 导出YOLOv3检测模型
python tools/export_model.py
-c
configs/yolov3/yolov3_darknet53_270e_coco.yml
--output_dir
=
./inference_model
\
-o
weights
=
https://paddledet.bj.bcebos.com/models/yolov3_darknet53_270e_coco.pdparams
# 导出HigherHRNet(bottom-up)关键点检测模型
python tools/export_model.py
-c
configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml
-o
weights
=
https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams
# 导出HRNet(top-down)关键点检测模型
python tools/export_model.py
-c
configs/keypoint/hrnet/hrnet_w32_384x288.yml
-o
weights
=
https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_384x288.pdparams
# 导出FairMOT多目标跟踪模型
python tools/export_model.py
-c
configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml
-o
weights
=
https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams
# 导出ByteTrack多目标跟踪模型(相当于只导出检测器)
python tools/export_model.py
-c
configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml
-o
weights
=
https://paddledet.bj.bcebos.com/models/mot/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
```
导出后目录下,包括
`infer_cfg.yml`
,
`model.pdiparams`
,
`model.pdiparams.info`
,
`model.pdmodel`
四个文件。
## 2. 基于Python的预测
## 2. 基于Python的预测
### 2.1 通用检测
在终端输入以下命令进行预测:
```
bash
python deploy/python/infer.py
--model_dir
=
./output_inference/yolov3_darknet53_270e_coco
--image_file
=
./demo/000000014439.jpg
--device
=
GPU
```
### 2.2 关键点检测
在终端输入以下命令进行预测:
```
bash
# keypoint top-down(HRNet)/bottom-up(HigherHRNet)单独推理,该模式下top-down模型HRNet只支持单人截图预测
python deploy/python/keypoint_infer.py
--model_dir
=
output_inference/hrnet_w32_384x288/
--image_file
=
./demo/hrnet_demo.jpg
--device
=
GPU
--threshold
=
0.5
python deploy/python/keypoint_infer.py
--model_dir
=
output_inference/higherhrnet_hrnet_w32_512/
--image_file
=
./demo/000000014439_640x640.jpg
--device
=
GPU
--threshold
=
0.5
# detector 检测 + keypoint top-down模型联合部署(联合推理只支持top-down关键点模型)
python deploy/python/det_keypoint_unite_infer.py
--det_model_dir
=
output_inference/yolov3_darknet53_270e_coco/
--keypoint_model_dir
=
output_inference/hrnet_w32_384x288/
--video_file
={
your video name
}
.mp4
--device
=
GPU
```
**注意:**
-
关键点检测模型导出和预测具体可参照
[
keypoint
](
../../configs/keypoint/README.md
)
,可分别在各个模型的文档中查找具体用法;
-
此目录下的关键点检测部署为基础前向功能,更多关键点检测功能可使用PP-Human项目,参照
[
pphuman
](
../pphuman/README.md
)
;
### 2.3 多目标跟踪
在终端输入以下命令进行预测:
```
bash
python deploy/python/infer.py
--model_dir
=
./output_inference/yolov3_mobilenet_v1_roadsign
--image_file
=
./demo/road554.png
--device
=
GPU
# FairMOT跟踪
python deploy/python/mot_jde_infer.py
--model_dir
=
output_inference/fairmot_dla34_30e_1088x608
--video_file
={
your video name
}
.mp4
--device
=
GPU
# ByteTrack跟踪
python deploy/python/mot_sde_infer.py
--model_dir
=
output_inference/ppyoloe_crn_l_36e_640x640_mot17half/
--tracker_config
=
deploy/python/tracker_config.yml
--video_file
={
your video name
}
.mp4
--device
=
GPU
--scaled
=
True
# FairMOT多目标跟踪联合HRNet关键点检测(联合推理只支持top-down关键点模型)
python deploy/python/mot_keypoint_unite_infer.py
--mot_model_dir
=
output_inference/fairmot_dla34_30e_1088x608/
--keypoint_model_dir
=
output_inference/hrnet_w32_384x288/
--video_file
={
your video name
}
.mp4
--device
=
GPU
```
**注意:**
-
多目标跟踪模型导出和预测具体可参照
[
mot]
](
../../configs/mot/README.md
)
,可分别在各个模型的文档中查找具体用法;
-
此目录下的跟踪部署为基础前向功能以及联合关键点部署,更多跟踪功能可使用PP-Human项目,参照
[
pphuman
](
../pphuman/README.md
)
,或PP-Tracking项目(绘制轨迹、出入口流量计数),参照
[
pptracking
](
../pptracking/README.md
)
;
参数说明如下:
| 参数 | 是否必须|含义 |
...
...
deploy/python/mot_jde_infer.py
浏览文件 @
822c7933
...
...
@@ -32,7 +32,7 @@ sys.path.insert(0, parent_path)
from
pptracking.python.mot
import
JDETracker
from
pptracking.python.mot.utils
import
MOTTimer
,
write_mot_results
from
pptracking.python.
visualize
import
plot_tracking
,
plot_tracking_dict
from
pptracking.python.
mot.visualize
import
plot_tracking_dict
# Global dictionary
MOT_JDE_SUPPORT_MODELS
=
{
...
...
@@ -55,9 +55,14 @@ class JDE_Detector(Detector):
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
output_dir (string): The path of output, default as 'output'
threshold (float): Score threshold of the detected bbox, default as 0.5
save_images (bool): Whether to save visualization image results, default as False
save_mot_txts (bool): Whether to save tracking results (txt), default as False
"""
def
__init__
(
self
,
def
__init__
(
self
,
model_dir
,
tracker_config
=
None
,
device
=
'CPU'
,
...
...
@@ -70,7 +75,9 @@ class JDE_Detector(Detector):
cpu_threads
=
1
,
enable_mkldnn
=
False
,
output_dir
=
'output'
,
threshold
=
0.5
):
threshold
=
0.5
,
save_images
=
False
,
save_mot_txts
=
False
,
):
super
(
JDE_Detector
,
self
).
__init__
(
model_dir
=
model_dir
,
device
=
device
,
...
...
@@ -84,6 +91,8 @@ class JDE_Detector(Detector):
enable_mkldnn
=
enable_mkldnn
,
output_dir
=
output_dir
,
threshold
=
threshold
,
)
self
.
save_images
=
save_images
self
.
save_mot_txts
=
save_mot_txts
assert
batch_size
==
1
,
"MOT model only supports batch_size=1."
self
.
det_times
=
Timer
(
with_tracker
=
True
)
self
.
num_classes
=
len
(
self
.
pred_config
.
labels
)
...
...
@@ -91,8 +100,8 @@ class JDE_Detector(Detector):
# tracker config
assert
self
.
pred_config
.
tracker
,
"The exported JDE Detector model should have tracker."
cfg
=
self
.
pred_config
.
tracker
min_box_area
=
cfg
.
get
(
'min_box_area'
,
20
0
)
vertical_ratio
=
cfg
.
get
(
'vertical_ratio'
,
1.6
)
min_box_area
=
cfg
.
get
(
'min_box_area'
,
0.
0
)
vertical_ratio
=
cfg
.
get
(
'vertical_ratio'
,
0.0
)
conf_thres
=
cfg
.
get
(
'conf_thres'
,
0.0
)
tracked_thresh
=
cfg
.
get
(
'tracked_thresh'
,
0.7
)
metric_type
=
cfg
.
get
(
'metric_type'
,
'euclidean'
)
...
...
@@ -115,7 +124,7 @@ class JDE_Detector(Detector):
return
result
def
tracking
(
self
,
det_results
):
pred_dets
=
det_results
[
'pred_dets'
]
#
'cls_id, score, x0, y0, x1, y1'
pred_dets
=
det_results
[
'pred_dets'
]
#
cls_id, score, x0, y0, x1, y1
pred_embs
=
det_results
[
'pred_embs'
]
online_targets_dict
=
self
.
tracker
.
update
(
pred_dets
,
pred_embs
)
...
...
@@ -164,7 +173,8 @@ class JDE_Detector(Detector):
image_list
,
run_benchmark
=
False
,
repeats
=
1
,
visual
=
True
):
visual
=
True
,
seq_name
=
None
):
mot_results
=
[]
num_classes
=
self
.
num_classes
image_list
.
sort
()
...
...
@@ -225,7 +235,7 @@ class JDE_Detector(Detector):
self
.
det_times
.
img_num
+=
1
if
visual
:
if
frame_id
%
10
==
0
:
if
len
(
image_list
)
>
1
and
frame_id
%
10
==
0
:
print
(
'Tracking frame {}'
.
format
(
frame_id
))
frame
,
_
=
decode_image
(
img_file
,
{})
...
...
@@ -237,6 +247,7 @@ class JDE_Detector(Detector):
online_scores
,
frame_id
=
frame_id
,
ids2names
=
ids2names
)
if
seq_name
is
None
:
seq_name
=
image_list
[
0
].
split
(
'/'
)[
-
2
]
save_dir
=
os
.
path
.
join
(
self
.
output_dir
,
seq_name
)
if
not
os
.
path
.
exists
(
save_dir
):
...
...
@@ -264,7 +275,8 @@ class JDE_Detector(Detector):
if
not
os
.
path
.
exists
(
self
.
output_dir
):
os
.
makedirs
(
self
.
output_dir
)
out_path
=
os
.
path
.
join
(
self
.
output_dir
,
video_out_name
)
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
'mp4v'
)
video_format
=
'mp4v'
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
video_format
)
writer
=
cv2
.
VideoWriter
(
out_path
,
fourcc
,
fps
,
(
width
,
height
))
frame_id
=
1
...
...
@@ -282,7 +294,9 @@ class JDE_Detector(Detector):
frame_id
+=
1
timer
.
tic
()
mot_results
=
self
.
predict_image
([
frame
],
visual
=
False
)
seq_name
=
video_out_name
.
split
(
'.'
)[
0
]
mot_results
=
self
.
predict_image
(
[
frame
],
visual
=
False
,
seq_name
=
seq_name
)
timer
.
toc
()
online_tlwhs
,
online_scores
,
online_ids
=
mot_results
[
0
]
...
...
@@ -307,20 +321,33 @@ class JDE_Detector(Detector):
cv2
.
imshow
(
'Mask Detection'
,
im
)
if
cv2
.
waitKey
(
1
)
&
0xFF
==
ord
(
'q'
):
break
if
self
.
save_mot_txts
:
result_filename
=
os
.
path
.
join
(
self
.
output_dir
,
video_out_name
.
split
(
'.'
)[
-
2
]
+
'.txt'
)
write_mot_results
(
result_filename
,
results
,
data_type
,
num_classes
)
writer
.
release
()
def
main
():
detector
=
JDE_Detector
(
FLAGS
.
model_dir
,
tracker_config
=
None
,
device
=
FLAGS
.
device
,
run_mode
=
FLAGS
.
run_mode
,
batch_size
=
1
,
trt_min_shape
=
FLAGS
.
trt_min_shape
,
trt_max_shape
=
FLAGS
.
trt_max_shape
,
trt_opt_shape
=
FLAGS
.
trt_opt_shape
,
trt_calib_mode
=
FLAGS
.
trt_calib_mode
,
cpu_threads
=
FLAGS
.
cpu_threads
,
enable_mkldnn
=
FLAGS
.
enable_mkldnn
)
enable_mkldnn
=
FLAGS
.
enable_mkldnn
,
output_dir
=
FLAGS
.
output_dir
,
threshold
=
FLAGS
.
threshold
,
save_images
=
FLAGS
.
save_images
,
save_mot_txts
=
FLAGS
.
save_mot_txts
)
# predict from video file or camera video stream
if
FLAGS
.
video_file
is
not
None
or
FLAGS
.
camera_id
!=
-
1
:
...
...
deploy/python/mot_keypoint_unite_infer.py
浏览文件 @
822c7933
...
...
@@ -24,7 +24,7 @@ from collections import defaultdict
from
mot_keypoint_unite_utils
import
argsparser
from
preprocess
import
decode_image
from
infer
import
print_arguments
,
get_test_images
from
infer
import
print_arguments
,
get_test_images
,
bench_log
from
mot_sde_infer
import
SDE_Detector
from
mot_jde_infer
import
JDE_Detector
,
MOT_JDE_SUPPORT_MODELS
from
keypoint_infer
import
KeyPointDetector
,
KEYPOINT_SUPPORT_MODELS
...
...
@@ -39,7 +39,7 @@ import sys
parent_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
__file__
,
*
([
'..'
]
*
2
)))
sys
.
path
.
insert
(
0
,
parent_path
)
from
pptracking.python.visualize
import
plot_tracking
,
plot_tracking_dict
from
pptracking.python.
mot.
visualize
import
plot_tracking
,
plot_tracking_dict
from
pptracking.python.mot.utils
import
MOTTimer
as
FPSTimer
...
...
@@ -92,7 +92,7 @@ def mot_topdown_unite_predict(mot_detector,
keypoint_res
=
predict_with_given_det
(
image
,
results
,
topdown_keypoint_detector
,
keypoint_batch_size
,
FLAGS
.
mot_threshold
,
FLAGS
.
keypoint_threshold
,
FLAGS
.
run_benchmark
)
FLAGS
.
run_benchmark
)
if
save_res
:
store_res
.
append
([
...
...
@@ -146,7 +146,7 @@ def mot_topdown_unite_predict_video(mot_detector,
if
not
os
.
path
.
exists
(
FLAGS
.
output_dir
):
os
.
makedirs
(
FLAGS
.
output_dir
)
out_path
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
video_name
)
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
'mp4v'
)
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
'mp4v'
)
writer
=
cv2
.
VideoWriter
(
out_path
,
fourcc
,
fps
,
(
width
,
height
))
frame_id
=
0
timer_mot
,
timer_kp
,
timer_mot_kp
=
FPSTimer
(),
FPSTimer
(),
FPSTimer
()
...
...
@@ -179,7 +179,7 @@ def mot_topdown_unite_predict_video(mot_detector,
timer_kp
.
tic
()
keypoint_res
=
predict_with_given_det
(
frame
,
results
,
topdown_keypoint_detector
,
keypoint_batch_size
,
FLAGS
.
mot_threshold
,
FLAGS
.
keypoint_threshold
,
FLAGS
.
run_benchmark
)
FLAGS
.
run_benchmark
)
timer_kp
.
toc
()
timer_mot_kp
.
toc
()
...
...
deploy/python/mot_sde_infer.py
浏览文件 @
822c7933
# Copyright (c) 202
2
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -23,15 +23,15 @@ import paddle
from
benchmark_utils
import
PaddleInferBenchmark
from
preprocess
import
decode_image
from
utils
import
argsparser
,
Timer
,
get_current_memory_mb
from
infer
import
Detector
,
get_test_images
,
print_arguments
,
bench_log
,
PredictConfig
from
infer
import
Detector
,
get_test_images
,
print_arguments
,
bench_log
,
PredictConfig
,
load_predictor
# add python path
import
sys
parent_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
__file__
,
*
([
'..'
]
*
2
)))
sys
.
path
.
insert
(
0
,
parent_path
)
from
pptracking.python.mot
import
JDETracker
from
pptracking.python.mot.utils
import
MOTTimer
,
write_mot_results
from
pptracking.python.mot
import
JDETracker
,
DeepSORTTracker
from
pptracking.python.mot.utils
import
MOTTimer
,
write_mot_results
,
get_crops
,
clip_box
from
pptracking.python.mot.visualize
import
plot_tracking
,
plot_tracking_dict
...
...
@@ -50,7 +50,11 @@ class SDE_Detector(Detector):
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
use_dark(bool): whether to use postprocess in DarkPose
output_dir (string): The path of output, default as 'output'
threshold (float): Score threshold of the detected bbox, default as 0.5
save_images (bool): Whether to save visualization image results, default as False
save_mot_txts (bool): Whether to save tracking results (txt), default as False
reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT
"""
def
__init__
(
self
,
...
...
@@ -66,7 +70,10 @@ class SDE_Detector(Detector):
cpu_threads
=
1
,
enable_mkldnn
=
False
,
output_dir
=
'output'
,
threshold
=
0.5
):
threshold
=
0.5
,
save_images
=
False
,
save_mot_txts
=
False
,
reid_model_dir
=
None
):
super
(
SDE_Detector
,
self
).
__init__
(
model_dir
=
model_dir
,
device
=
device
,
...
...
@@ -80,37 +87,163 @@ class SDE_Detector(Detector):
enable_mkldnn
=
enable_mkldnn
,
output_dir
=
output_dir
,
threshold
=
threshold
,
)
self
.
save_images
=
save_images
self
.
save_mot_txts
=
save_mot_txts
assert
batch_size
==
1
,
"MOT model only supports batch_size=1."
self
.
det_times
=
Timer
(
with_tracker
=
True
)
self
.
num_classes
=
len
(
self
.
pred_config
.
labels
)
# tracker config
# reid config
self
.
use_reid
=
False
if
reid_model_dir
is
None
else
True
if
self
.
use_reid
:
self
.
reid_pred_config
=
self
.
set_config
(
reid_model_dir
)
self
.
reid_predictor
,
self
.
config
=
load_predictor
(
reid_model_dir
,
run_mode
=
run_mode
,
batch_size
=
50
,
# reid_batch_size
min_subgraph_size
=
self
.
reid_pred_config
.
min_subgraph_size
,
device
=
device
,
use_dynamic_shape
=
self
.
reid_pred_config
.
use_dynamic_shape
,
trt_min_shape
=
trt_min_shape
,
trt_max_shape
=
trt_max_shape
,
trt_opt_shape
=
trt_opt_shape
,
trt_calib_mode
=
trt_calib_mode
,
cpu_threads
=
cpu_threads
,
enable_mkldnn
=
enable_mkldnn
)
else
:
self
.
reid_pred_config
=
None
self
.
reid_predictor
=
None
assert
tracker_config
is
not
None
,
'Note that tracker_config should be set.'
self
.
tracker_config
=
tracker_config
cfg
=
yaml
.
safe_load
(
open
(
self
.
tracker_config
))[
'tracker'
]
min_box_area
=
cfg
.
get
(
'min_box_area'
,
200
)
vertical_ratio
=
cfg
.
get
(
'vertical_ratio'
,
1.6
)
use_byte
=
cfg
.
get
(
'use_byte'
,
True
)
tracker_cfg
=
yaml
.
safe_load
(
open
(
self
.
tracker_config
))
cfg
=
tracker_cfg
[
tracker_cfg
[
'type'
]]
# tracker config
self
.
use_deepsort_tracker
=
True
if
tracker_cfg
[
'type'
]
==
'DeepSORTTracker'
else
False
if
self
.
use_deepsort_tracker
:
# use DeepSORTTracker
if
self
.
reid_pred_config
is
not
None
and
hasattr
(
self
.
reid_pred_config
,
'tracker'
):
cfg
=
self
.
reid_pred_config
.
tracker
budget
=
cfg
.
get
(
'budget'
,
100
)
max_age
=
cfg
.
get
(
'max_age'
,
30
)
max_iou_distance
=
cfg
.
get
(
'max_iou_distance'
,
0.7
)
matching_threshold
=
cfg
.
get
(
'matching_threshold'
,
0.2
)
min_box_area
=
cfg
.
get
(
'min_box_area'
,
0
)
vertical_ratio
=
cfg
.
get
(
'vertical_ratio'
,
0
)
self
.
tracker
=
DeepSORTTracker
(
budget
=
budget
,
max_age
=
max_age
,
max_iou_distance
=
max_iou_distance
,
matching_threshold
=
matching_threshold
,
min_box_area
=
min_box_area
,
vertical_ratio
=
vertical_ratio
,
)
else
:
# use ByteTracker
use_byte
=
cfg
.
get
(
'use_byte'
,
False
)
det_thresh
=
cfg
.
get
(
'det_thresh'
,
0.3
)
min_box_area
=
cfg
.
get
(
'min_box_area'
,
0
)
vertical_ratio
=
cfg
.
get
(
'vertical_ratio'
,
0
)
match_thres
=
cfg
.
get
(
'match_thres'
,
0.9
)
conf_thres
=
cfg
.
get
(
'conf_thres'
,
0.6
)
low_conf_thres
=
cfg
.
get
(
'low_conf_thres'
,
0.1
)
self
.
tracker
=
JDETracker
(
use_byte
=
use_byte
,
det_thresh
=
det_thresh
,
num_classes
=
self
.
num_classes
,
min_box_area
=
min_box_area
,
vertical_ratio
=
vertical_ratio
,
match_thres
=
match_thres
,
conf_thres
=
conf_thres
,
low_conf_thres
=
low_conf_thres
)
low_conf_thres
=
low_conf_thres
,
)
def
postprocess
(
self
,
inputs
,
result
):
# postprocess output of predictor
np_boxes_num
=
result
[
'boxes_num'
]
if
np_boxes_num
[
0
]
<=
0
:
print
(
'[WARNNING] No object detected.'
)
result
=
{
'boxes'
:
np
.
zeros
([
0
,
6
]),
'boxes_num'
:
[
0
]}
result
=
{
k
:
v
for
k
,
v
in
result
.
items
()
if
v
is
not
None
}
return
result
def
reidprocess
(
self
,
det_results
,
repeats
=
1
):
pred_dets
=
det_results
[
'boxes'
]
pred_xyxys
=
pred_dets
[:,
2
:
6
]
ori_image
=
det_results
[
'ori_image'
]
ori_image_shape
=
ori_image
.
shape
[:
2
]
pred_xyxys
,
keep_idx
=
clip_box
(
pred_xyxys
,
ori_image_shape
)
if
len
(
keep_idx
[
0
])
==
0
:
det_results
[
'boxes'
]
=
np
.
zeros
((
1
,
6
),
dtype
=
np
.
float32
)
det_results
[
'embeddings'
]
=
None
return
det_results
pred_dets
=
pred_dets
[
keep_idx
[
0
]]
pred_xyxys
=
pred_dets
[:,
2
:
6
]
w
,
h
=
self
.
tracker
.
input_size
crops
=
get_crops
(
pred_xyxys
,
ori_image
,
w
,
h
)
# to keep fast speed, only use topk crops
crops
=
crops
[:
50
]
# reid_batch_size
det_results
[
'crops'
]
=
np
.
array
(
crops
).
astype
(
'float32'
)
det_results
[
'boxes'
]
=
pred_dets
[:
50
]
input_names
=
self
.
reid_predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
reid_predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
det_results
[
input_names
[
i
]])
# model prediction
for
i
in
range
(
repeats
):
self
.
reid_predictor
.
run
()
output_names
=
self
.
reid_predictor
.
get_output_names
()
feature_tensor
=
self
.
reid_predictor
.
get_output_handle
(
output_names
[
0
])
pred_embs
=
feature_tensor
.
copy_to_cpu
()
det_results
[
'embeddings'
]
=
pred_embs
return
det_results
def
tracking
(
self
,
det_results
):
pred_dets
=
det_results
[
'boxes'
]
# 'cls_id, score, x0, y0, x1, y1'
pred_embs
=
None
pred_embs
=
det_results
.
get
(
'embeddings'
,
None
)
online_targets_dict
=
self
.
tracker
.
update
(
pred_dets
,
pred_embs
)
if
self
.
use_deepsort_tracker
:
# use DeepSORTTracker, only support singe class
self
.
tracker
.
predict
()
online_targets
=
self
.
tracker
.
update
(
pred_dets
,
pred_embs
)
online_tlwhs
,
online_scores
,
online_ids
=
[],
[],
[]
for
t
in
online_targets
:
if
not
t
.
is_confirmed
()
or
t
.
time_since_update
>
1
:
continue
tlwh
=
t
.
to_tlwh
()
tscore
=
t
.
score
tid
=
t
.
track_id
if
self
.
tracker
.
vertical_ratio
>
0
and
tlwh
[
2
]
/
tlwh
[
3
]
>
self
.
tracker
.
vertical_ratio
:
continue
online_tlwhs
.
append
(
tlwh
)
online_scores
.
append
(
tscore
)
online_ids
.
append
(
tid
)
tracking_outs
=
{
'online_tlwhs'
:
online_tlwhs
,
'online_scores'
:
online_scores
,
'online_ids'
:
online_ids
,
}
return
tracking_outs
else
:
# use ByteTracker, support multiple class
online_tlwhs
=
defaultdict
(
list
)
online_scores
=
defaultdict
(
list
)
online_ids
=
defaultdict
(
list
)
online_targets_dict
=
self
.
tracker
.
update
(
pred_dets
,
pred_embs
)
for
cls_id
in
range
(
self
.
num_classes
):
online_targets
=
online_targets_dict
[
cls_id
]
for
t
in
online_targets
:
...
...
@@ -126,19 +259,26 @@ class SDE_Detector(Detector):
online_ids
[
cls_id
].
append
(
tid
)
online_scores
[
cls_id
].
append
(
tscore
)
return
online_tlwhs
,
online_scores
,
online_ids
tracking_outs
=
{
'online_tlwhs'
:
online_tlwhs
,
'online_scores'
:
online_scores
,
'online_ids'
:
online_ids
,
}
return
tracking_outs
def
predict_image
(
self
,
image_list
,
run_benchmark
=
False
,
repeats
=
1
,
visual
=
True
):
mot_results
=
[]
visual
=
True
,
seq_name
=
None
):
num_classes
=
self
.
num_classes
image_list
.
sort
()
ids2names
=
self
.
pred_config
.
labels
mot_results
=
[]
for
frame_id
,
img_file
in
enumerate
(
image_list
):
batch_image_list
=
[
img_file
]
# bs=1 in MOT model
frame
,
_
=
decode_image
(
img_file
,
{})
if
run_benchmark
:
# preprocess
inputs
=
self
.
preprocess
(
batch_image_list
)
# warmup
...
...
@@ -159,10 +299,16 @@ class SDE_Detector(Detector):
self
.
det_times
.
postprocess_time_s
.
end
()
# tracking
if
self
.
use_reid
:
det_result
[
'frame_id'
]
=
frame_id
det_result
[
'seq_name'
]
=
seq_name
det_result
[
'ori_image'
]
=
frame
det_result
=
self
.
reidprocess
(
det_result
)
result_warmup
=
self
.
tracking
(
det_result
)
self
.
det_times
.
tracking_time_s
.
start
()
online_tlwhs
,
online_scores
,
online_ids
=
self
.
tracking
(
det_result
)
if
self
.
use_reid
:
det_result
=
self
.
reidprocess
(
det_result
)
tracking_outs
=
self
.
tracking
(
det_result
)
self
.
det_times
.
tracking_time_s
.
end
()
self
.
det_times
.
img_num
+=
1
...
...
@@ -186,16 +332,26 @@ class SDE_Detector(Detector):
# tracking process
self
.
det_times
.
tracking_time_s
.
start
()
online_tlwhs
,
online_scores
,
online_ids
=
self
.
tracking
(
det_result
)
if
self
.
use_reid
:
det_result
[
'frame_id'
]
=
frame_id
det_result
[
'seq_name'
]
=
seq_name
det_result
[
'ori_image'
]
=
frame
det_result
=
self
.
reidprocess
(
det_result
)
tracking_outs
=
self
.
tracking
(
det_result
)
self
.
det_times
.
tracking_time_s
.
end
()
self
.
det_times
.
img_num
+=
1
online_tlwhs
=
tracking_outs
[
'online_tlwhs'
]
online_scores
=
tracking_outs
[
'online_scores'
]
online_ids
=
tracking_outs
[
'online_ids'
]
mot_results
.
append
([
online_tlwhs
,
online_scores
,
online_ids
])
if
visual
:
if
frame_id
%
10
==
0
:
if
len
(
image_list
)
>
1
and
frame_id
%
10
==
0
:
print
(
'Tracking frame {}'
.
format
(
frame_id
))
frame
,
_
=
decode_image
(
img_file
,
{})
if
isinstance
(
online_tlwhs
,
defaultdict
):
im
=
plot_tracking_dict
(
frame
,
num_classes
,
...
...
@@ -204,14 +360,19 @@ class SDE_Detector(Detector):
online_scores
,
frame_id
=
frame_id
,
ids2names
=
[])
seq_name
=
image_list
[
0
].
split
(
'/'
)[
-
2
]
else
:
im
=
plot_tracking
(
frame
,
online_tlwhs
,
online_ids
,
online_scores
,
frame_id
=
frame_id
)
save_dir
=
os
.
path
.
join
(
self
.
output_dir
,
seq_name
)
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
cv2
.
imwrite
(
os
.
path
.
join
(
save_dir
,
'{:05d}.jpg'
.
format
(
frame_id
)),
im
)
mot_results
.
append
([
online_tlwhs
,
online_scores
,
online_ids
])
return
mot_results
def
predict_video
(
self
,
video_file
,
camera_id
):
...
...
@@ -231,13 +392,17 @@ class SDE_Detector(Detector):
if
not
os
.
path
.
exists
(
self
.
output_dir
):
os
.
makedirs
(
self
.
output_dir
)
out_path
=
os
.
path
.
join
(
self
.
output_dir
,
video_out_name
)
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
'mp4v'
)
video_format
=
'mp4v'
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
video_format
)
writer
=
cv2
.
VideoWriter
(
out_path
,
fourcc
,
fps
,
(
width
,
height
))
frame_id
=
1
timer
=
MOTTimer
()
results
=
defaultdict
(
list
)
# support single class and multi classes
results
=
defaultdict
(
list
)
num_classes
=
self
.
num_classes
data_type
=
'mcmot'
if
num_classes
>
1
else
'mot'
ids2names
=
self
.
pred_config
.
labels
while
(
1
):
ret
,
frame
=
capture
.
read
()
if
not
ret
:
...
...
@@ -247,16 +412,32 @@ class SDE_Detector(Detector):
frame_id
+=
1
timer
.
tic
()
mot_results
=
self
.
predict_image
([
frame
],
visual
=
False
)
seq_name
=
video_out_name
.
split
(
'.'
)[
0
]
mot_results
=
self
.
predict_image
(
[
frame
],
visual
=
False
,
seq_name
=
seq_name
)
timer
.
toc
()
# bs=1 in MOT model
online_tlwhs
,
online_scores
,
online_ids
=
mot_results
[
0
]
for
cls_id
in
range
(
num_classes
):
results
[
cls_id
].
append
(
(
frame_id
+
1
,
online_tlwhs
[
cls_id
],
online_scores
[
cls_id
],
online_ids
[
cls_id
]))
fps
=
1.
/
timer
.
duration
if
self
.
use_deepsort_tracker
:
# use DeepSORTTracker, only support singe class
results
[
0
].
append
(
(
frame_id
+
1
,
online_tlwhs
,
online_scores
,
online_ids
))
im
=
plot_tracking
(
frame
,
online_tlwhs
,
online_ids
,
online_scores
,
frame_id
=
frame_id
,
fps
=
fps
)
else
:
# use ByteTracker, support multiple class
for
cls_id
in
range
(
num_classes
):
results
[
cls_id
].
append
(
(
frame_id
+
1
,
online_tlwhs
[
cls_id
],
online_scores
[
cls_id
],
online_ids
[
cls_id
]))
im
=
plot_tracking_dict
(
frame
,
num_classes
,
...
...
@@ -265,13 +446,19 @@ class SDE_Detector(Detector):
online_scores
,
frame_id
=
frame_id
,
fps
=
fps
,
ids2names
=
[]
)
ids2names
=
ids2names
)
writer
.
write
(
im
)
if
camera_id
!=
-
1
:
cv2
.
imshow
(
'Mask Detection'
,
im
)
if
cv2
.
waitKey
(
1
)
&
0xFF
==
ord
(
'q'
):
break
if
self
.
save_mot_txts
:
result_filename
=
os
.
path
.
join
(
self
.
output_dir
,
video_out_name
.
split
(
'.'
)[
-
2
]
+
'.txt'
)
write_mot_results
(
result_filename
,
results
)
writer
.
release
()
...
...
@@ -282,18 +469,20 @@ def main():
arch
=
yml_conf
[
'arch'
]
detector
=
SDE_Detector
(
FLAGS
.
model_dir
,
FLAGS
.
tracker_config
,
tracker_config
=
FLAGS
.
tracker_config
,
device
=
FLAGS
.
device
,
run_mode
=
FLAGS
.
run_mode
,
batch_size
=
FLAGS
.
batch_size
,
batch_size
=
1
,
trt_min_shape
=
FLAGS
.
trt_min_shape
,
trt_max_shape
=
FLAGS
.
trt_max_shape
,
trt_opt_shape
=
FLAGS
.
trt_opt_shape
,
trt_calib_mode
=
FLAGS
.
trt_calib_mode
,
cpu_threads
=
FLAGS
.
cpu_threads
,
enable_mkldnn
=
FLAGS
.
enable_mkldnn
,
output_dir
=
FLAGS
.
output_dir
,
threshold
=
FLAGS
.
threshold
,
output_dir
=
FLAGS
.
output_dir
)
save_images
=
FLAGS
.
save_images
,
save_mot_txts
=
FLAGS
.
save_mot_txts
,
)
# predict from video file or camera video stream
if
FLAGS
.
video_file
is
not
None
or
FLAGS
.
camera_id
!=
-
1
:
...
...
@@ -303,7 +492,9 @@ def main():
if
FLAGS
.
image_dir
is
None
and
FLAGS
.
image_file
is
not
None
:
assert
FLAGS
.
batch_size
==
1
,
"--batch_size should be 1 in MOT models."
img_list
=
get_test_images
(
FLAGS
.
image_dir
,
FLAGS
.
image_file
)
detector
.
predict_image
(
img_list
,
FLAGS
.
run_benchmark
,
repeats
=
10
)
seq_name
=
FLAGS
.
image_dir
.
split
(
'/'
)[
-
1
]
detector
.
predict_image
(
img_list
,
FLAGS
.
run_benchmark
,
repeats
=
10
,
seq_name
=
seq_name
)
if
not
FLAGS
.
run_benchmark
:
detector
.
det_times
.
info
(
average
=
True
)
...
...
deploy/python/tracker_config.yml
浏览文件 @
822c7933
# config of tracker for MOT SDE Detector, use
ByteTracker
as default.
# The tracker of MOT JDE Detector is exported together with the model.
# config of tracker for MOT SDE Detector, use
'JDETracker'
as default.
# The tracker of MOT JDE Detector
(such as FairMOT)
is exported together with the model.
# Here 'min_box_area' and 'vertical_ratio' are set for pedestrian, you can modify for other objects tracking.
tracker
:
use_byte
:
true
type
:
JDETracker
# 'JDETracker' or 'DeepSORTTracker'
# BYTETracker
JDETracker
:
use_byte
:
True
det_thresh
:
0.3
conf_thres
:
0.6
low_conf_thres
:
0.1
match_thres
:
0.9
min_box_area
:
100
vertical_ratio
:
1.6
min_box_area
:
0
vertical_ratio
:
0
# 1.6 for pedestrian
DeepSORTTracker
:
input_size
:
[
64
,
192
]
min_box_area
:
0
vertical_ratio
:
-1
budget
:
100
max_age
:
70
n_init
:
3
metric_type
:
cosine
matching_threshold
:
0.2
max_iou_distance
:
0.9
ppdet/modeling/mot/tracker/jde_tracker.py
浏览文件 @
822c7933
...
...
@@ -44,7 +44,7 @@ class JDETracker(object):
track_buffer (int): buffer for tracker
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
bad results. If set <0 means no need to filter bboxes,usually set
bad results. If set <
=
0 means no need to filter bboxes,usually set
1.6 for pedestrian tracking.
tracked_thresh (float): linear assignment threshold of tracked
stracks and detections
...
...
@@ -70,8 +70,8 @@ class JDETracker(object):
num_classes
=
1
,
det_thresh
=
0.3
,
track_buffer
=
30
,
min_box_area
=
20
0
,
vertical_ratio
=
1.6
,
min_box_area
=
0
,
vertical_ratio
=
0
,
tracked_thresh
=
0.7
,
r_tracked_thresh
=
0.5
,
unconfirmed_thresh
=
0.7
,
...
...
@@ -167,9 +167,8 @@ class JDETracker(object):
detections
=
[
STrack
(
STrack
.
tlbr_to_tlwh
(
tlbrs
[
2
:
6
]),
tlbrs
[
1
],
cls_id
,
30
,
temp_feat
)
for
(
tlbrs
,
temp_feat
)
in
zip
(
pred_dets_cls
,
pred_embs_cls
)
30
,
temp_feat
)
for
(
tlbrs
,
temp_feat
)
in
zip
(
pred_dets_cls
,
pred_embs_cls
)
]
else
:
detections
=
[]
...
...
@@ -244,15 +243,13 @@ class JDETracker(object):
for
tlbrs
in
pred_dets_cls_second
]
else
:
pred_embs_cls_second
=
pred_embs_dict
[
cls_id
][
inds_second
]
pred_embs_cls_second
=
pred_embs_dict
[
cls_id
][
inds_second
]
detections_second
=
[
STrack
(
STrack
.
tlbr_to_tlwh
(
tlbrs
[
2
:
6
]),
tlbrs
[
1
],
cls_id
,
30
,
temp_feat
)
for
(
tlbrs
,
temp_feat
)
in
zip
(
pred_dets_cls_second
,
pred_embs_cls_second
)
STrack
.
tlbr_to_tlwh
(
tlbrs
[
2
:
6
]),
tlbrs
[
1
],
cls_id
,
30
,
temp_feat
)
for
(
tlbrs
,
temp_feat
)
in
zip
(
pred_dets_cls_second
,
pred_embs_cls_second
)
]
else
:
detections_second
=
[]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录