Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
39ff9f2f
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
39ff9f2f
编写于
6月 30, 2021
作者:
W
wangguanzhong
提交者:
GitHub
6月 30, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix score threshold in mot_infer (#3444)
上级
1264fde9
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
25 addition
and
10 deletion
+25
-10
deploy/python/mot_infer.py
deploy/python/mot_infer.py
+6
-4
ppdet/engine/tracker.py
ppdet/engine/tracker.py
+12
-5
tools/infer_mot.py
tools/infer_mot.py
+7
-1
未找到文件。
deploy/python/mot_infer.py
浏览文件 @
39ff9f2f
...
...
@@ -93,7 +93,7 @@ class MOT_Detector(object):
inputs
=
create_inputs
(
im
,
im_info
)
return
inputs
def
postprocess
(
self
,
pred_dets
,
pred_embs
):
def
postprocess
(
self
,
pred_dets
,
pred_embs
,
threshold
):
online_targets
=
self
.
tracker
.
update
(
pred_dets
,
pred_embs
)
online_tlwhs
,
online_ids
=
[],
[]
online_scores
=
[]
...
...
@@ -101,6 +101,7 @@ class MOT_Detector(object):
tlwh
=
t
.
tlwh
tid
=
t
.
track_id
tscore
=
t
.
score
if
tscore
<
threshold
:
continue
vertical
=
tlwh
[
2
]
/
tlwh
[
3
]
>
1.6
if
tlwh
[
2
]
*
tlwh
[
3
]
>
self
.
tracker
.
min_box_area
and
not
vertical
:
online_tlwhs
.
append
(
tlwh
)
...
...
@@ -137,8 +138,8 @@ class MOT_Detector(object):
self
.
det_times
.
inference_time_s
.
end
(
repeats
=
repeats
)
self
.
det_times
.
postprocess_time_s
.
start
()
online_tlwhs
,
online_scores
,
online_ids
=
self
.
postprocess
(
pred_dets
,
pred_embs
)
online_tlwhs
,
online_scores
,
online_ids
=
self
.
postprocess
(
pred_dets
,
pred_embs
,
threshold
)
self
.
det_times
.
postprocess_time_s
.
end
()
self
.
det_times
.
img_num
+=
1
return
online_tlwhs
,
online_scores
,
online_ids
...
...
@@ -363,7 +364,8 @@ def predict_video(detector, camera_id):
online_ids
,
online_scores
,
frame_id
=
frame_id
,
fps
=
fps
)
fps
=
fps
,
threhold
=
FLAGS
.
threshold
)
if
FLAGS
.
save_images
:
save_dir
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
video_name
.
split
(
'.'
)[
-
2
])
if
not
os
.
path
.
exists
(
save_dir
):
...
...
ppdet/engine/tracker.py
浏览文件 @
39ff9f2f
...
...
@@ -112,7 +112,8 @@ class Tracker(object):
dataloader
,
save_dir
=
None
,
show_image
=
False
,
frame_rate
=
30
):
frame_rate
=
30
,
draw_threshold
=
0
):
if
save_dir
:
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
tracker
=
self
.
model
.
tracker
...
...
@@ -140,6 +141,7 @@ class Tracker(object):
tlwh
=
t
.
tlwh
tid
=
t
.
track_id
tscore
=
t
.
score
if
tscore
<
draw_threshold
:
continue
vertical
=
tlwh
[
2
]
/
tlwh
[
3
]
>
1.6
if
tlwh
[
2
]
*
tlwh
[
3
]
>
tracker
.
min_box_area
and
not
vertical
:
online_tlwhs
.
append
(
tlwh
)
...
...
@@ -162,7 +164,8 @@ class Tracker(object):
save_dir
=
None
,
show_image
=
False
,
frame_rate
=
30
,
det_file
=
''
):
det_file
=
''
,
draw_threshold
=
0
):
if
save_dir
:
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
tracker
=
self
.
model
.
tracker
...
...
@@ -191,6 +194,7 @@ class Tracker(object):
dets
=
dets_list
[
frame_id
]
bbox_tlwh
=
paddle
.
to_tensor
(
dets
[
'bbox'
],
dtype
=
'float32'
)
pred_scores
=
paddle
.
to_tensor
(
dets
[
'score'
],
dtype
=
'float32'
)
if
pred_scores
<
draw_threshold
:
continue
if
bbox_tlwh
.
shape
[
0
]
>
0
:
pred_bboxes
=
paddle
.
concat
(
(
bbox_tlwh
[:,
0
:
2
],
...
...
@@ -343,7 +347,8 @@ class Tracker(object):
save_images
=
False
,
save_videos
=
True
,
show_image
=
False
,
det_results_dir
=
''
):
det_results_dir
=
''
,
draw_threshold
=
0.5
):
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
output_dir
)
result_root
=
os
.
path
.
join
(
output_dir
,
'mot_results'
)
if
not
os
.
path
.
exists
(
result_root
):
os
.
makedirs
(
result_root
)
...
...
@@ -369,7 +374,8 @@ class Tracker(object):
dataloader
,
save_dir
=
save_dir
,
show_image
=
show_image
,
frame_rate
=
frame_rate
)
frame_rate
=
frame_rate
,
draw_threshold
=
draw_threshold
)
elif
model_type
in
[
'DeepSORT'
]:
results
,
nf
,
ta
,
tc
=
self
.
_eval_seq_sde
(
dataloader
,
...
...
@@ -377,7 +383,8 @@ class Tracker(object):
show_image
=
show_image
,
frame_rate
=
frame_rate
,
det_file
=
os
.
path
.
join
(
det_results_dir
,
'{}.txt'
.
format
(
seq
)))
'{}.txt'
.
format
(
seq
)),
draw_threshold
=
draw_threshold
)
else
:
raise
ValueError
(
model_type
)
...
...
tools/infer_mot.py
浏览文件 @
39ff9f2f
...
...
@@ -68,6 +68,11 @@ def parse_args():
'--show_image'
,
action
=
'store_true'
,
help
=
'Show tracking results (image).'
)
parser
.
add_argument
(
"--draw_threshold"
,
type
=
float
,
default
=
0.5
,
help
=
"Threshold to reserve the result for visualization."
)
args
=
parser
.
parse_args
()
return
args
...
...
@@ -94,7 +99,8 @@ def run(FLAGS, cfg):
save_images
=
FLAGS
.
save_images
,
save_videos
=
FLAGS
.
save_videos
,
show_image
=
FLAGS
.
show_image
,
det_results_dir
=
FLAGS
.
det_results_dir
)
det_results_dir
=
FLAGS
.
det_results_dir
,
draw_threshold
=
FLAGS
.
draw_threshold
)
def
main
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录