Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
d53f6412
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d53f6412
编写于
12月 16, 2021
作者:
C
chenjian
提交者:
GitHub
12月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix mot version imcompatible (#1709)
上级
9c8d8959
变更
31
隐藏空白更改
内联
并排
Showing
31 changed file
with
3572 addition
and
14 deletion
+3572
-14
modules/video/multiple_object_tracking/fairmot_dla34/config/_base_/fairmot_dla34.yml
...ct_tracking/fairmot_dla34/config/_base_/fairmot_dla34.yml
+1
-1
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/__init__.py
...le_object_tracking/fairmot_dla34/modeling/mot/__init__.py
+25
-0
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/__init__.py
..._tracking/fairmot_dla34/modeling/mot/matching/__init__.py
+19
-0
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/deepsort_matching.py
.../fairmot_dla34/modeling/mot/matching/deepsort_matching.py
+368
-0
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/jde_matching.py
...cking/fairmot_dla34/modeling/mot/matching/jde_matching.py
+123
-0
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/motion/__init__.py
...ct_tracking/fairmot_dla34/modeling/mot/motion/__init__.py
+17
-0
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/motion/kalman_filter.py
...acking/fairmot_dla34/modeling/mot/motion/kalman_filter.py
+237
-0
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/__init__.py
...t_tracking/fairmot_dla34/modeling/mot/tracker/__init__.py
+21
-0
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/base_jde_tracker.py
...ng/fairmot_dla34/modeling/mot/tracker/base_jde_tracker.py
+257
-0
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/base_sde_tracker.py
...ng/fairmot_dla34/modeling/mot/tracker/base_sde_tracker.py
+133
-0
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/jde_tracker.py
...racking/fairmot_dla34/modeling/mot/tracker/jde_tracker.py
+248
-0
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/utils.py
...tiple_object_tracking/fairmot_dla34/modeling/mot/utils.py
+176
-0
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/visualization.py
...ject_tracking/fairmot_dla34/modeling/mot/visualization.py
+117
-0
modules/video/multiple_object_tracking/fairmot_dla34/tracker.py
...s/video/multiple_object_tracking/fairmot_dla34/tracker.py
+4
-5
modules/video/multiple_object_tracking/fairmot_dla34/utils.py
...les/video/multiple_object_tracking/fairmot_dla34/utils.py
+39
-0
modules/video/multiple_object_tracking/jde_darknet53/config/_base_/jde_darknet53.yml
...ct_tracking/jde_darknet53/config/_base_/jde_darknet53.yml
+1
-1
modules/video/multiple_object_tracking/jde_darknet53/config/jde_darknet53_30e_1088x608.yml
...cking/jde_darknet53/config/jde_darknet53_30e_1088x608.yml
+1
-1
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/__init__.py
...le_object_tracking/jde_darknet53/modeling/mot/__init__.py
+25
-0
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/__init__.py
..._tracking/jde_darknet53/modeling/mot/matching/__init__.py
+19
-0
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/deepsort_matching.py
.../jde_darknet53/modeling/mot/matching/deepsort_matching.py
+368
-0
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/jde_matching.py
...cking/jde_darknet53/modeling/mot/matching/jde_matching.py
+123
-0
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/motion/__init__.py
...ct_tracking/jde_darknet53/modeling/mot/motion/__init__.py
+17
-0
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/motion/kalman_filter.py
...acking/jde_darknet53/modeling/mot/motion/kalman_filter.py
+237
-0
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/__init__.py
...t_tracking/jde_darknet53/modeling/mot/tracker/__init__.py
+21
-0
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/base_jde_tracker.py
...ng/jde_darknet53/modeling/mot/tracker/base_jde_tracker.py
+257
-0
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/base_sde_tracker.py
...ng/jde_darknet53/modeling/mot/tracker/base_sde_tracker.py
+133
-0
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/jde_tracker.py
...racking/jde_darknet53/modeling/mot/tracker/jde_tracker.py
+248
-0
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/utils.py
...tiple_object_tracking/jde_darknet53/modeling/mot/utils.py
+176
-0
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/visualization.py
...ject_tracking/jde_darknet53/modeling/mot/visualization.py
+117
-0
modules/video/multiple_object_tracking/jde_darknet53/tracker.py
...s/video/multiple_object_tracking/jde_darknet53/tracker.py
+5
-6
modules/video/multiple_object_tracking/jde_darknet53/utils.py
...les/video/multiple_object_tracking/jde_darknet53/utils.py
+39
-0
未找到文件。
modules/video/multiple_object_tracking/fairmot_dla34/config/_base_/fairmot_dla34.yml
浏览文件 @
d53f6412
...
...
@@ -5,7 +5,7 @@ FairMOT:
detector
:
CenterNet
reid
:
FairMOTEmbeddingHead
loss
:
FairMOTLoss
tracker
:
JDETracker
tracker
:
Frozen
JDETracker
CenterNet
:
backbone
:
DLA
...
...
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/__init__.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
matching
from
.
import
tracker
from
.
import
motion
from
.
import
visualization
from
.
import
utils
from
.matching
import
*
from
.tracker
import
*
from
.motion
import
*
from
.visualization
import
*
from
.utils
import
*
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/__init__.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
jde_matching
from
.
import
deepsort_matching
from
.jde_matching
import
*
from
.deepsort_matching
import
*
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/deepsort_matching.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/tree/master/deep_sort
"""
import
numpy
as
np
from
scipy.optimize
import
linear_sum_assignment
from
..motion
import
kalman_filter
INFTY_COST
=
1e+5
__all__
=
[
'iou_1toN'
,
'iou_cost'
,
'_nn_euclidean_distance'
,
'_nn_cosine_distance'
,
'NearestNeighborDistanceMetric'
,
'min_cost_matching'
,
'matching_cascade'
,
'gate_cost_matrix'
,
]
def
iou_1toN
(
bbox
,
candidates
):
"""
Computer intersection over union (IoU) by one box to N candidates.
Args:
bbox (ndarray): A bounding box in format `(top left x, top left y, width, height)`.
candidates (ndarray): A matrix of candidate bounding boxes (one per row) in the
same format as `bbox`.
Returns:
ious (ndarray): The intersection over union in [0, 1] between the `bbox`
and each candidate. A higher score means a larger fraction of the
`bbox` is occluded by the candidate.
"""
bbox_tl
=
bbox
[:
2
]
bbox_br
=
bbox
[:
2
]
+
bbox
[
2
:]
candidates_tl
=
candidates
[:,
:
2
]
candidates_br
=
candidates
[:,
:
2
]
+
candidates
[:,
2
:]
tl
=
np
.
c_
[
np
.
maximum
(
bbox_tl
[
0
],
candidates_tl
[:,
0
])[:,
np
.
newaxis
],
np
.
maximum
(
bbox_tl
[
1
],
candidates_tl
[:,
1
])[:,
np
.
newaxis
]]
br
=
np
.
c_
[
np
.
minimum
(
bbox_br
[
0
],
candidates_br
[:,
0
])[:,
np
.
newaxis
],
np
.
minimum
(
bbox_br
[
1
],
candidates_br
[:,
1
])[:,
np
.
newaxis
]]
wh
=
np
.
maximum
(
0.
,
br
-
tl
)
area_intersection
=
wh
.
prod
(
axis
=
1
)
area_bbox
=
bbox
[
2
:].
prod
()
area_candidates
=
candidates
[:,
2
:].
prod
(
axis
=
1
)
ious
=
area_intersection
/
(
area_bbox
+
area_candidates
-
area_intersection
)
return
ious
def
iou_cost
(
tracks
,
detections
,
track_indices
=
None
,
detection_indices
=
None
):
"""
IoU distance metric.
Args:
tracks (list[Track]): A list of tracks.
detections (list[Detection]): A list of detections.
track_indices (Optional[list[int]]): A list of indices to tracks that
should be matched. Defaults to all `tracks`.
detection_indices (Optional[list[int]]): A list of indices to detections
that should be matched. Defaults to all `detections`.
Returns:
cost_matrix (ndarray): A cost matrix of shape len(track_indices),
len(detection_indices) where entry (i, j) is
`1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
"""
if
track_indices
is
None
:
track_indices
=
np
.
arange
(
len
(
tracks
))
if
detection_indices
is
None
:
detection_indices
=
np
.
arange
(
len
(
detections
))
cost_matrix
=
np
.
zeros
((
len
(
track_indices
),
len
(
detection_indices
)))
for
row
,
track_idx
in
enumerate
(
track_indices
):
if
tracks
[
track_idx
].
time_since_update
>
1
:
cost_matrix
[
row
,
:]
=
1e+5
continue
bbox
=
tracks
[
track_idx
].
to_tlwh
()
candidates
=
np
.
asarray
([
detections
[
i
].
tlwh
for
i
in
detection_indices
])
cost_matrix
[
row
,
:]
=
1.
-
iou_1toN
(
bbox
,
candidates
)
return
cost_matrix
def
_nn_euclidean_distance
(
s
,
q
):
"""
Compute pair-wise squared (Euclidean) distance between points in `s` and `q`.
Args:
s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
Returns:
distances (ndarray): A vector of length M that contains for each entry in `q` the
smallest Euclidean distance to a sample in `s`.
"""
s
,
q
=
np
.
asarray
(
s
),
np
.
asarray
(
q
)
if
len
(
s
)
==
0
or
len
(
q
)
==
0
:
return
np
.
zeros
((
len
(
s
),
len
(
q
)))
s2
,
q2
=
np
.
square
(
s
).
sum
(
axis
=
1
),
np
.
square
(
q
).
sum
(
axis
=
1
)
distances
=
-
2.
*
np
.
dot
(
s
,
q
.
T
)
+
s2
[:,
None
]
+
q2
[
None
,
:]
distances
=
np
.
clip
(
distances
,
0.
,
float
(
np
.
inf
))
return
np
.
maximum
(
0.0
,
distances
.
min
(
axis
=
0
))
def
_nn_cosine_distance
(
s
,
q
):
"""
Compute pair-wise cosine distance between points in `s` and `q`.
Args:
s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
Returns:
distances (ndarray): A vector of length M that contains for each entry in `q` the
smallest Euclidean distance to a sample in `s`.
"""
s
=
np
.
asarray
(
s
)
/
np
.
linalg
.
norm
(
s
,
axis
=
1
,
keepdims
=
True
)
q
=
np
.
asarray
(
q
)
/
np
.
linalg
.
norm
(
q
,
axis
=
1
,
keepdims
=
True
)
distances
=
1.
-
np
.
dot
(
s
,
q
.
T
)
return
distances
.
min
(
axis
=
0
)
class
NearestNeighborDistanceMetric
(
object
):
"""
A nearest neighbor distance metric that, for each target, returns
the closest distance to any sample that has been observed so far.
Args:
metric (str): Either "euclidean" or "cosine".
matching_threshold (float): The matching threshold. Samples with larger
distance are considered an invalid match.
budget (Optional[int]): If not None, fix samples per class to at most
this number. Removes the oldest samples when the budget is reached.
Attributes:
samples (Dict[int -> List[ndarray]]): A dictionary that maps from target
identities to the list of samples that have been observed so far.
"""
def
__init__
(
self
,
metric
,
matching_threshold
,
budget
=
None
):
if
metric
==
"euclidean"
:
self
.
_metric
=
_nn_euclidean_distance
elif
metric
==
"cosine"
:
self
.
_metric
=
_nn_cosine_distance
else
:
raise
ValueError
(
"Invalid metric; must be either 'euclidean' or 'cosine'"
)
self
.
matching_threshold
=
matching_threshold
self
.
budget
=
budget
self
.
samples
=
{}
def
partial_fit
(
self
,
features
,
targets
,
active_targets
):
"""
Update the distance metric with new data.
Args:
features (ndarray): An NxM matrix of N features of dimensionality M.
targets (ndarray): An integer array of associated target identities.
active_targets (List[int]): A list of targets that are currently
present in the scene.
"""
for
feature
,
target
in
zip
(
features
,
targets
):
self
.
samples
.
setdefault
(
target
,
[]).
append
(
feature
)
if
self
.
budget
is
not
None
:
self
.
samples
[
target
]
=
self
.
samples
[
target
][
-
self
.
budget
:]
self
.
samples
=
{
k
:
self
.
samples
[
k
]
for
k
in
active_targets
}
def
distance
(
self
,
features
,
targets
):
"""
Compute distance between features and targets.
Args:
features (ndarray): An NxM matrix of N features of dimensionality M.
targets (list[int]): A list of targets to match the given `features` against.
Returns:
cost_matrix (ndarray): a cost matrix of shape len(targets), len(features),
where element (i, j) contains the closest squared distance between
`targets[i]` and `features[j]`.
"""
cost_matrix
=
np
.
zeros
((
len
(
targets
),
len
(
features
)))
for
i
,
target
in
enumerate
(
targets
):
cost_matrix
[
i
,
:]
=
self
.
_metric
(
self
.
samples
[
target
],
features
)
return
cost_matrix
def
min_cost_matching
(
distance_metric
,
max_distance
,
tracks
,
detections
,
track_indices
=
None
,
detection_indices
=
None
):
"""
Solve linear assignment problem.
Args:
distance_metric :
Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
The distance metric is given a list of tracks and detections as
well as a list of N track indices and M detection indices. The
metric should return the NxM dimensional cost matrix, where element
(i, j) is the association cost between the i-th track in the given
track indices and the j-th detection in the given detection_indices.
max_distance (float): Gating threshold. Associations with cost larger
than this value are disregarded.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (list[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
Returns:
A tuple (List[(int, int)], List[int], List[int]) with the following
three entries:
* A list of matched track and detection indices.
* A list of unmatched track indices.
* A list of unmatched detection indices.
"""
if
track_indices
is
None
:
track_indices
=
np
.
arange
(
len
(
tracks
))
if
detection_indices
is
None
:
detection_indices
=
np
.
arange
(
len
(
detections
))
if
len
(
detection_indices
)
==
0
or
len
(
track_indices
)
==
0
:
return
[],
track_indices
,
detection_indices
# Nothing to match.
cost_matrix
=
distance_metric
(
tracks
,
detections
,
track_indices
,
detection_indices
)
cost_matrix
[
cost_matrix
>
max_distance
]
=
max_distance
+
1e-5
indices
=
linear_sum_assignment
(
cost_matrix
)
matches
,
unmatched_tracks
,
unmatched_detections
=
[],
[],
[]
for
col
,
detection_idx
in
enumerate
(
detection_indices
):
if
col
not
in
indices
[
1
]:
unmatched_detections
.
append
(
detection_idx
)
for
row
,
track_idx
in
enumerate
(
track_indices
):
if
row
not
in
indices
[
0
]:
unmatched_tracks
.
append
(
track_idx
)
for
row
,
col
in
zip
(
indices
[
0
],
indices
[
1
]):
track_idx
=
track_indices
[
row
]
detection_idx
=
detection_indices
[
col
]
if
cost_matrix
[
row
,
col
]
>
max_distance
:
unmatched_tracks
.
append
(
track_idx
)
unmatched_detections
.
append
(
detection_idx
)
else
:
matches
.
append
((
track_idx
,
detection_idx
))
return
matches
,
unmatched_tracks
,
unmatched_detections
def
matching_cascade
(
distance_metric
,
max_distance
,
cascade_depth
,
tracks
,
detections
,
track_indices
=
None
,
detection_indices
=
None
):
"""
Run matching cascade.
Args:
distance_metric :
Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
The distance metric is given a list of tracks and detections as
well as a list of N track indices and M detection indices. The
metric should return the NxM dimensional cost matrix, where element
(i, j) is the association cost between the i-th track in the given
track indices and the j-th detection in the given detection_indices.
max_distance (float): Gating threshold. Associations with cost larger
than this value are disregarded.
cascade_depth (int): The cascade depth, should be se to the maximum
track age.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (list[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
Returns:
A tuple (List[(int, int)], List[int], List[int]) with the following
three entries:
* A list of matched track and detection indices.
* A list of unmatched track indices.
* A list of unmatched detection indices.
"""
if
track_indices
is
None
:
track_indices
=
list
(
range
(
len
(
tracks
)))
if
detection_indices
is
None
:
detection_indices
=
list
(
range
(
len
(
detections
)))
unmatched_detections
=
detection_indices
matches
=
[]
for
level
in
range
(
cascade_depth
):
if
len
(
unmatched_detections
)
==
0
:
# No detections left
break
track_indices_l
=
[
k
for
k
in
track_indices
if
tracks
[
k
].
time_since_update
==
1
+
level
]
if
len
(
track_indices_l
)
==
0
:
# Nothing to match at this level
continue
matches_l
,
_
,
unmatched_detections
=
\
min_cost_matching
(
distance_metric
,
max_distance
,
tracks
,
detections
,
track_indices_l
,
unmatched_detections
)
matches
+=
matches_l
unmatched_tracks
=
list
(
set
(
track_indices
)
-
set
(
k
for
k
,
_
in
matches
))
return
matches
,
unmatched_tracks
,
unmatched_detections
def
gate_cost_matrix
(
kf
,
cost_matrix
,
tracks
,
detections
,
track_indices
,
detection_indices
,
gated_cost
=
INFTY_COST
,
only_position
=
False
):
"""
Invalidate infeasible entries in cost matrix based on the state
distributions obtained by Kalman filtering.
Args:
kf (object): The Kalman filter.
cost_matrix (ndarray): The NxM dimensional cost matrix, where N is the
number of track indices and M is the number of detection indices,
such that entry (i, j) is the association cost between
`tracks[track_indices[i]]` and `detections[detection_indices[j]]`.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (List[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
gated_cost (Optional[float]): Entries in the cost matrix corresponding
to infeasible associations are set this value. Defaults to a very
large value.
only_position (Optional[bool]): If True, only the x, y position of the
state distribution is considered during gating. Default False.
"""
gating_dim
=
2
if
only_position
else
4
gating_threshold
=
kalman_filter
.
chi2inv95
[
gating_dim
]
measurements
=
np
.
asarray
([
detections
[
i
].
to_xyah
()
for
i
in
detection_indices
])
for
row
,
track_idx
in
enumerate
(
track_indices
):
track
=
tracks
[
track_idx
]
gating_distance
=
kf
.
gating_distance
(
track
.
mean
,
track
.
covariance
,
measurements
,
only_position
)
cost_matrix
[
row
,
gating_distance
>
gating_threshold
]
=
gated_cost
return
cost_matrix
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/jde_matching.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/matching.py
"""
import
lap
import
scipy
import
numpy
as
np
from
scipy.spatial.distance
import
cdist
from
..motion
import
kalman_filter
from
ppdet.utils.logger
import
setup_logger
logger
=
setup_logger
(
__name__
)
__all__
=
[
'merge_matches'
,
'linear_assignment'
,
'cython_bbox_ious'
,
'iou_distance'
,
'embedding_distance'
,
'fuse_motion'
,
]
def
merge_matches
(
m1
,
m2
,
shape
):
O
,
P
,
Q
=
shape
m1
=
np
.
asarray
(
m1
)
m2
=
np
.
asarray
(
m2
)
M1
=
scipy
.
sparse
.
coo_matrix
((
np
.
ones
(
len
(
m1
)),
(
m1
[:,
0
],
m1
[:,
1
])),
shape
=
(
O
,
P
))
M2
=
scipy
.
sparse
.
coo_matrix
((
np
.
ones
(
len
(
m2
)),
(
m2
[:,
0
],
m2
[:,
1
])),
shape
=
(
P
,
Q
))
mask
=
M1
*
M2
match
=
mask
.
nonzero
()
match
=
list
(
zip
(
match
[
0
],
match
[
1
]))
unmatched_O
=
tuple
(
set
(
range
(
O
))
-
set
([
i
for
i
,
j
in
match
]))
unmatched_Q
=
tuple
(
set
(
range
(
Q
))
-
set
([
j
for
i
,
j
in
match
]))
return
match
,
unmatched_O
,
unmatched_Q
def
linear_assignment
(
cost_matrix
,
thresh
):
if
cost_matrix
.
size
==
0
:
return
np
.
empty
((
0
,
2
),
dtype
=
int
),
tuple
(
range
(
cost_matrix
.
shape
[
0
])),
tuple
(
range
(
cost_matrix
.
shape
[
1
]))
matches
,
unmatched_a
,
unmatched_b
=
[],
[],
[]
cost
,
x
,
y
=
lap
.
lapjv
(
cost_matrix
,
extend_cost
=
True
,
cost_limit
=
thresh
)
for
ix
,
mx
in
enumerate
(
x
):
if
mx
>=
0
:
matches
.
append
([
ix
,
mx
])
unmatched_a
=
np
.
where
(
x
<
0
)[
0
]
unmatched_b
=
np
.
where
(
y
<
0
)[
0
]
matches
=
np
.
asarray
(
matches
)
return
matches
,
unmatched_a
,
unmatched_b
def
cython_bbox_ious
(
atlbrs
,
btlbrs
):
ious
=
np
.
zeros
((
len
(
atlbrs
),
len
(
btlbrs
)),
dtype
=
np
.
float
)
if
ious
.
size
==
0
:
return
ious
try
:
import
cython_bbox
except
Exception
as
e
:
logger
.
error
(
'cython_bbox not found, please install cython_bbox.'
'for example: `pip install cython_bbox`.'
)
raise
e
ious
=
cython_bbox
.
bbox_overlaps
(
np
.
ascontiguousarray
(
atlbrs
,
dtype
=
np
.
float
),
np
.
ascontiguousarray
(
btlbrs
,
dtype
=
np
.
float
))
return
ious
def
iou_distance
(
atracks
,
btracks
):
"""
Compute cost based on IoU between two list[STrack].
"""
if
(
len
(
atracks
)
>
0
and
isinstance
(
atracks
[
0
],
np
.
ndarray
))
or
(
len
(
btracks
)
>
0
and
isinstance
(
btracks
[
0
],
np
.
ndarray
)):
atlbrs
=
atracks
btlbrs
=
btracks
else
:
atlbrs
=
[
track
.
tlbr
for
track
in
atracks
]
btlbrs
=
[
track
.
tlbr
for
track
in
btracks
]
_ious
=
cython_bbox_ious
(
atlbrs
,
btlbrs
)
cost_matrix
=
1
-
_ious
return
cost_matrix
def
embedding_distance
(
tracks
,
detections
,
metric
=
'euclidean'
):
"""
Compute cost based on features between two list[STrack].
"""
cost_matrix
=
np
.
zeros
((
len
(
tracks
),
len
(
detections
)),
dtype
=
np
.
float
)
if
cost_matrix
.
size
==
0
:
return
cost_matrix
det_features
=
np
.
asarray
([
track
.
curr_feat
for
track
in
detections
],
dtype
=
np
.
float
)
track_features
=
np
.
asarray
([
track
.
smooth_feat
for
track
in
tracks
],
dtype
=
np
.
float
)
cost_matrix
=
np
.
maximum
(
0.0
,
cdist
(
track_features
,
det_features
,
metric
))
# Nomalized features
return
cost_matrix
def
fuse_motion
(
kf
,
cost_matrix
,
tracks
,
detections
,
only_position
=
False
,
lambda_
=
0.98
):
if
cost_matrix
.
size
==
0
:
return
cost_matrix
gating_dim
=
2
if
only_position
else
4
gating_threshold
=
kalman_filter
.
chi2inv95
[
gating_dim
]
measurements
=
np
.
asarray
([
det
.
to_xyah
()
for
det
in
detections
])
for
row
,
track
in
enumerate
(
tracks
):
gating_distance
=
kf
.
gating_distance
(
track
.
mean
,
track
.
covariance
,
measurements
,
only_position
,
metric
=
'maha'
)
cost_matrix
[
row
,
gating_distance
>
gating_threshold
]
=
np
.
inf
cost_matrix
[
row
]
=
lambda_
*
cost_matrix
[
row
]
+
(
1
-
lambda_
)
*
gating_distance
return
cost_matrix
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/motion/__init__.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
kalman_filter
from
.kalman_filter
import
*
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/motion/kalman_filter.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/kalman_filter.py
"""
import
numpy
as
np
import
scipy.linalg
__all__
=
[
'KalmanFilter'
]
"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
function and used as Mahalanobis gating threshold.
"""
chi2inv95
=
{
1
:
3.8415
,
2
:
5.9915
,
3
:
7.8147
,
4
:
9.4877
,
5
:
11.070
,
6
:
12.592
,
7
:
14.067
,
8
:
15.507
,
9
:
16.919
}
class
KalmanFilter
(
object
):
"""
A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space
x, y, a, h, vx, vy, va, vh
contains the bounding box center position (x, y), aspect ratio a, height h,
and their respective velocities.
Object motion follows a constant velocity model. The bounding box location
(x, y, a, h) is taken as direct observation of the state space (linear
observation model).
"""
def
__init__
(
self
):
ndim
,
dt
=
4
,
1.
# Create Kalman filter model matrices.
self
.
_motion_mat
=
np
.
eye
(
2
*
ndim
,
2
*
ndim
)
for
i
in
range
(
ndim
):
self
.
_motion_mat
[
i
,
ndim
+
i
]
=
dt
self
.
_update_mat
=
np
.
eye
(
ndim
,
2
*
ndim
)
# Motion and observation uncertainty are chosen relative to the current
# state estimate. These weights control the amount of uncertainty in
# the model. This is a bit hacky.
self
.
_std_weight_position
=
1.
/
20
self
.
_std_weight_velocity
=
1.
/
160
def
initiate
(
self
,
measurement
):
"""
Create track from unassociated measurement.
Args:
measurement (ndarray): Bounding box coordinates (x, y, a, h) with
center position (x, y), aspect ratio a, and height h.
Returns:
The mean vector (8 dimensional) and covariance matrix (8x8
dimensional) of the new track. Unobserved velocities are
initialized to 0 mean.
"""
mean_pos
=
measurement
mean_vel
=
np
.
zeros_like
(
mean_pos
)
mean
=
np
.
r_
[
mean_pos
,
mean_vel
]
std
=
[
2
*
self
.
_std_weight_position
*
measurement
[
3
],
2
*
self
.
_std_weight_position
*
measurement
[
3
],
1e-2
,
2
*
self
.
_std_weight_position
*
measurement
[
3
],
10
*
self
.
_std_weight_velocity
*
measurement
[
3
],
10
*
self
.
_std_weight_velocity
*
measurement
[
3
],
1e-5
,
10
*
self
.
_std_weight_velocity
*
measurement
[
3
]
]
covariance
=
np
.
diag
(
np
.
square
(
std
))
return
mean
,
covariance
def
predict
(
self
,
mean
,
covariance
):
"""
Run Kalman filter prediction step.
Args:
mean (ndarray): The 8 dimensional mean vector of the object state
at the previous time step.
covariance (ndarray): The 8x8 dimensional covariance matrix of the
object state at the previous time step.
Returns:
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos
=
[
self
.
_std_weight_position
*
mean
[
3
],
self
.
_std_weight_position
*
mean
[
3
],
1e-2
,
self
.
_std_weight_position
*
mean
[
3
]
]
std_vel
=
[
self
.
_std_weight_velocity
*
mean
[
3
],
self
.
_std_weight_velocity
*
mean
[
3
],
1e-5
,
self
.
_std_weight_velocity
*
mean
[
3
]
]
motion_cov
=
np
.
diag
(
np
.
square
(
np
.
r_
[
std_pos
,
std_vel
]))
#mean = np.dot(self._motion_mat, mean)
mean
=
np
.
dot
(
mean
,
self
.
_motion_mat
.
T
)
covariance
=
np
.
linalg
.
multi_dot
((
self
.
_motion_mat
,
covariance
,
self
.
_motion_mat
.
T
))
+
motion_cov
return
mean
,
covariance
def
project
(
self
,
mean
,
covariance
):
"""
Project state distribution to measurement space.
Args
mean (ndarray): The state's mean vector (8 dimensional array).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
Returns:
The projected mean and covariance matrix of the given state estimate.
"""
std
=
[
self
.
_std_weight_position
*
mean
[
3
],
self
.
_std_weight_position
*
mean
[
3
],
1e-1
,
self
.
_std_weight_position
*
mean
[
3
]
]
innovation_cov
=
np
.
diag
(
np
.
square
(
std
))
mean
=
np
.
dot
(
self
.
_update_mat
,
mean
)
covariance
=
np
.
linalg
.
multi_dot
((
self
.
_update_mat
,
covariance
,
self
.
_update_mat
.
T
))
return
mean
,
covariance
+
innovation_cov
def
multi_predict
(
self
,
mean
,
covariance
):
"""
Run Kalman filter prediction step (Vectorized version).
Args:
mean (ndarray): The Nx8 dimensional mean matrix of the object states
at the previous time step.
covariance (ndarray): The Nx8x8 dimensional covariance matrics of the
object states at the previous time step.
Returns:
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos
=
[
self
.
_std_weight_position
*
mean
[:,
3
],
self
.
_std_weight_position
*
mean
[:,
3
],
1e-2
*
np
.
ones_like
(
mean
[:,
3
]),
self
.
_std_weight_position
*
mean
[:,
3
]
]
std_vel
=
[
self
.
_std_weight_velocity
*
mean
[:,
3
],
self
.
_std_weight_velocity
*
mean
[:,
3
],
1e-5
*
np
.
ones_like
(
mean
[:,
3
]),
self
.
_std_weight_velocity
*
mean
[:,
3
]
]
sqr
=
np
.
square
(
np
.
r_
[
std_pos
,
std_vel
]).
T
motion_cov
=
[]
for
i
in
range
(
len
(
mean
)):
motion_cov
.
append
(
np
.
diag
(
sqr
[
i
]))
motion_cov
=
np
.
asarray
(
motion_cov
)
mean
=
np
.
dot
(
mean
,
self
.
_motion_mat
.
T
)
left
=
np
.
dot
(
self
.
_motion_mat
,
covariance
).
transpose
((
1
,
0
,
2
))
covariance
=
np
.
dot
(
left
,
self
.
_motion_mat
.
T
)
+
motion_cov
return
mean
,
covariance
def
update
(
self
,
mean
,
covariance
,
measurement
):
"""
Run Kalman filter correction step.
Args:
mean (ndarray): The predicted state's mean vector (8 dimensional).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
measurement (ndarray): The 4 dimensional measurement vector
(x, y, a, h), where (x, y) is the center position, a the aspect
ratio, and h the height of the bounding box.
Returns:
The measurement-corrected state distribution.
"""
projected_mean
,
projected_cov
=
self
.
project
(
mean
,
covariance
)
chol_factor
,
lower
=
scipy
.
linalg
.
cho_factor
(
projected_cov
,
lower
=
True
,
check_finite
=
False
)
kalman_gain
=
scipy
.
linalg
.
cho_solve
((
chol_factor
,
lower
),
np
.
dot
(
covariance
,
self
.
_update_mat
.
T
).
T
,
check_finite
=
False
).
T
innovation
=
measurement
-
projected_mean
new_mean
=
mean
+
np
.
dot
(
innovation
,
kalman_gain
.
T
)
new_covariance
=
covariance
-
np
.
linalg
.
multi_dot
((
kalman_gain
,
projected_cov
,
kalman_gain
.
T
))
return
new_mean
,
new_covariance
def
gating_distance
(
self
,
mean
,
covariance
,
measurements
,
only_position
=
False
,
metric
=
'maha'
):
"""
Compute gating distance between state distribution and measurements.
A suitable distance threshold can be obtained from `chi2inv95`. If
`only_position` is False, the chi-square distribution has 4 degrees of
freedom, otherwise 2.
Args:
mean (ndarray): Mean vector over the state distribution (8
dimensional).
covariance (ndarray): Covariance of the state distribution (8x8
dimensional).
measurements (ndarray): An Nx4 dimensional matrix of N measurements,
each in format (x, y, a, h) where (x, y) is the bounding box center
position, a the aspect ratio, and h the height.
only_position (Optional[bool]): If True, distance computation is
done with respect to the bounding box center position only.
metric (str): Metric type, 'gaussian' or 'maha'.
Returns
An array of length N, where the i-th element contains the squared
Mahalanobis distance between (mean, covariance) and `measurements[i]`.
"""
mean
,
covariance
=
self
.
project
(
mean
,
covariance
)
if
only_position
:
mean
,
covariance
=
mean
[:
2
],
covariance
[:
2
,
:
2
]
measurements
=
measurements
[:,
:
2
]
d
=
measurements
-
mean
if
metric
==
'gaussian'
:
return
np
.
sum
(
d
*
d
,
axis
=
1
)
elif
metric
==
'maha'
:
cholesky_factor
=
np
.
linalg
.
cholesky
(
covariance
)
z
=
scipy
.
linalg
.
solve_triangular
(
cholesky_factor
,
d
.
T
,
lower
=
True
,
check_finite
=
False
,
overwrite_b
=
True
)
squared_maha
=
np
.
sum
(
z
*
z
,
axis
=
0
)
return
squared_maha
else
:
raise
ValueError
(
'invalid distance metric'
)
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/__init__.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
base_jde_tracker
from
.
import
base_sde_tracker
from
.
import
jde_tracker
from
.base_jde_tracker
import
*
from
.base_sde_tracker
import
*
from
.jde_tracker
import
*
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/base_jde_tracker.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import
numpy
as
np
from
collections
import
deque
,
OrderedDict
from
..matching
import
jde_matching
as
matching
from
ppdet.core.workspace
import
register
,
serializable
__all__
=
[
'TrackState'
,
'BaseTrack'
,
'STrack'
,
'joint_stracks'
,
'sub_stracks'
,
'remove_duplicate_stracks'
,
]
class
TrackState
(
object
):
New
=
0
Tracked
=
1
Lost
=
2
Removed
=
3
class
BaseTrack
(
object
):
_count
=
0
track_id
=
0
is_activated
=
False
state
=
TrackState
.
New
history
=
OrderedDict
()
features
=
[]
curr_feature
=
None
score
=
0
start_frame
=
0
frame_id
=
0
time_since_update
=
0
# multi-camera
location
=
(
np
.
inf
,
np
.
inf
)
@
property
def
end_frame
(
self
):
return
self
.
frame_id
@
staticmethod
def
next_id
():
BaseTrack
.
_count
+=
1
return
BaseTrack
.
_count
def
activate
(
self
,
*
args
):
raise
NotImplementedError
def
predict
(
self
):
raise
NotImplementedError
def
update
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
def
mark_lost
(
self
):
self
.
state
=
TrackState
.
Lost
def
mark_removed
(
self
):
self
.
state
=
TrackState
.
Removed
class
STrack
(
BaseTrack
):
def
__init__
(
self
,
tlwh
,
score
,
temp_feat
,
buffer_size
=
30
):
# wait activate
self
.
_tlwh
=
np
.
asarray
(
tlwh
,
dtype
=
np
.
float
)
self
.
kalman_filter
=
None
self
.
mean
,
self
.
covariance
=
None
,
None
self
.
is_activated
=
False
self
.
score
=
score
self
.
tracklet_len
=
0
self
.
smooth_feat
=
None
self
.
update_features
(
temp_feat
)
self
.
features
=
deque
([],
maxlen
=
buffer_size
)
self
.
alpha
=
0.9
def
update_features
(
self
,
feat
):
feat
/=
np
.
linalg
.
norm
(
feat
)
self
.
curr_feat
=
feat
if
self
.
smooth_feat
is
None
:
self
.
smooth_feat
=
feat
else
:
self
.
smooth_feat
=
self
.
alpha
*
self
.
smooth_feat
+
(
1
-
self
.
alpha
)
*
feat
self
.
features
.
append
(
feat
)
self
.
smooth_feat
/=
np
.
linalg
.
norm
(
self
.
smooth_feat
)
def
predict
(
self
):
mean_state
=
self
.
mean
.
copy
()
if
self
.
state
!=
TrackState
.
Tracked
:
mean_state
[
7
]
=
0
self
.
mean
,
self
.
covariance
=
self
.
kalman_filter
.
predict
(
mean_state
,
self
.
covariance
)
@
staticmethod
def
multi_predict
(
stracks
,
kalman_filter
):
if
len
(
stracks
)
>
0
:
multi_mean
=
np
.
asarray
([
st
.
mean
.
copy
()
for
st
in
stracks
])
multi_covariance
=
np
.
asarray
([
st
.
covariance
for
st
in
stracks
])
for
i
,
st
in
enumerate
(
stracks
):
if
st
.
state
!=
TrackState
.
Tracked
:
multi_mean
[
i
][
7
]
=
0
multi_mean
,
multi_covariance
=
kalman_filter
.
multi_predict
(
multi_mean
,
multi_covariance
)
for
i
,
(
mean
,
cov
)
in
enumerate
(
zip
(
multi_mean
,
multi_covariance
)):
stracks
[
i
].
mean
=
mean
stracks
[
i
].
covariance
=
cov
def
activate
(
self
,
kalman_filter
,
frame_id
):
"""Start a new tracklet"""
self
.
kalman_filter
=
kalman_filter
self
.
track_id
=
self
.
next_id
()
self
.
mean
,
self
.
covariance
=
self
.
kalman_filter
.
initiate
(
self
.
tlwh_to_xyah
(
self
.
_tlwh
))
self
.
tracklet_len
=
0
self
.
state
=
TrackState
.
Tracked
if
frame_id
==
1
:
self
.
is_activated
=
True
self
.
frame_id
=
frame_id
self
.
start_frame
=
frame_id
def
re_activate
(
self
,
new_track
,
frame_id
,
new_id
=
False
):
self
.
mean
,
self
.
covariance
=
self
.
kalman_filter
.
update
(
self
.
mean
,
self
.
covariance
,
self
.
tlwh_to_xyah
(
new_track
.
tlwh
))
self
.
update_features
(
new_track
.
curr_feat
)
self
.
tracklet_len
=
0
self
.
state
=
TrackState
.
Tracked
self
.
is_activated
=
True
self
.
frame_id
=
frame_id
if
new_id
:
self
.
track_id
=
self
.
next_id
()
def
update
(
self
,
new_track
,
frame_id
,
update_feature
=
True
):
self
.
frame_id
=
frame_id
self
.
tracklet_len
+=
1
new_tlwh
=
new_track
.
tlwh
self
.
mean
,
self
.
covariance
=
self
.
kalman_filter
.
update
(
self
.
mean
,
self
.
covariance
,
self
.
tlwh_to_xyah
(
new_tlwh
))
self
.
state
=
TrackState
.
Tracked
self
.
is_activated
=
True
self
.
score
=
new_track
.
score
if
update_feature
:
self
.
update_features
(
new_track
.
curr_feat
)
@
property
def
tlwh
(
self
):
"""
Get current position in bounding box format `(top left x, top left y,
width, height)`.
"""
if
self
.
mean
is
None
:
return
self
.
_tlwh
.
copy
()
ret
=
self
.
mean
[:
4
].
copy
()
ret
[
2
]
*=
ret
[
3
]
ret
[:
2
]
-=
ret
[
2
:]
/
2
return
ret
@
property
def
tlbr
(
self
):
"""
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret
=
self
.
tlwh
.
copy
()
ret
[
2
:]
+=
ret
[:
2
]
return
ret
@
staticmethod
def
tlwh_to_xyah
(
tlwh
):
"""
Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret
=
np
.
asarray
(
tlwh
).
copy
()
ret
[:
2
]
+=
ret
[
2
:]
/
2
ret
[
2
]
/=
ret
[
3
]
return
ret
def
to_xyah
(
self
):
return
self
.
tlwh_to_xyah
(
self
.
tlwh
)
@
staticmethod
def
tlbr_to_tlwh
(
tlbr
):
ret
=
np
.
asarray
(
tlbr
).
copy
()
ret
[
2
:]
-=
ret
[:
2
]
return
ret
@
staticmethod
def
tlwh_to_tlbr
(
tlwh
):
ret
=
np
.
asarray
(
tlwh
).
copy
()
ret
[
2
:]
+=
ret
[:
2
]
return
ret
def
__repr__
(
self
):
return
'OT_{}_({}-{})'
.
format
(
self
.
track_id
,
self
.
start_frame
,
self
.
end_frame
)
def
joint_stracks
(
tlista
,
tlistb
):
exists
=
{}
res
=
[]
for
t
in
tlista
:
exists
[
t
.
track_id
]
=
1
res
.
append
(
t
)
for
t
in
tlistb
:
tid
=
t
.
track_id
if
not
exists
.
get
(
tid
,
0
):
exists
[
tid
]
=
1
res
.
append
(
t
)
return
res
def
sub_stracks
(
tlista
,
tlistb
):
stracks
=
{}
for
t
in
tlista
:
stracks
[
t
.
track_id
]
=
t
for
t
in
tlistb
:
tid
=
t
.
track_id
if
stracks
.
get
(
tid
,
0
):
del
stracks
[
tid
]
return
list
(
stracks
.
values
())
def
remove_duplicate_stracks
(
stracksa
,
stracksb
):
pdist
=
matching
.
iou_distance
(
stracksa
,
stracksb
)
pairs
=
np
.
where
(
pdist
<
0.15
)
dupa
,
dupb
=
list
(),
list
()
for
p
,
q
in
zip
(
*
pairs
):
timep
=
stracksa
[
p
].
frame_id
-
stracksa
[
p
].
start_frame
timeq
=
stracksb
[
q
].
frame_id
-
stracksb
[
q
].
start_frame
if
timep
>
timeq
:
dupb
.
append
(
q
)
else
:
dupa
.
append
(
p
)
resa
=
[
t
for
i
,
t
in
enumerate
(
stracksa
)
if
not
i
in
dupa
]
resb
=
[
t
for
i
,
t
in
enumerate
(
stracksb
)
if
not
i
in
dupb
]
return
resa
,
resb
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/base_sde_tracker.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/track.py
"""
from
ppdet.core.workspace
import
register
,
serializable
__all__
=
[
'TrackState'
,
'Track'
]
class
TrackState
(
object
):
"""
Enumeration type for the single target track state. Newly created tracks are
classified as `tentative` until enough evidence has been collected. Then,
the track state is changed to `confirmed`. Tracks that are no longer alive
are classified as `deleted` to mark them for removal from the set of active
tracks.
"""
Tentative
=
1
Confirmed
=
2
Deleted
=
3
class
Track
(
object
):
"""
A single target track with state space `(x, y, a, h)` and associated
velocities, where `(x, y)` is the center of the bounding box, `a` is the
aspect ratio and `h` is the height.
Args:
mean (ndarray): Mean vector of the initial state distribution.
covariance (ndarray): Covariance matrix of the initial state distribution.
track_id (int): A unique track identifier.
n_init (int): Number of consecutive detections before the track is confirmed.
The track state is set to `Deleted` if a miss occurs within the first
`n_init` frames.
max_age (int): The maximum number of consecutive misses before the track
state is set to `Deleted`.
feature (Optional[ndarray]): Feature vector of the detection this track
originates from. If not None, this feature is added to the `features` cache.
Attributes:
hits (int): Total number of measurement updates.
age (int): Total number of frames since first occurance.
time_since_update (int): Total number of frames since last measurement
update.
state (TrackState): The current track state.
features (List[ndarray]): A cache of features. On each measurement update,
the associated feature vector is added to this list.
"""
def
__init__
(
self
,
mean
,
covariance
,
track_id
,
n_init
,
max_age
,
feature
=
None
):
self
.
mean
=
mean
self
.
covariance
=
covariance
self
.
track_id
=
track_id
self
.
hits
=
1
self
.
age
=
1
self
.
time_since_update
=
0
self
.
state
=
TrackState
.
Tentative
self
.
features
=
[]
if
feature
is
not
None
:
self
.
features
.
append
(
feature
)
self
.
_n_init
=
n_init
self
.
_max_age
=
max_age
def
to_tlwh
(
self
):
"""Get position in format `(top left x, top left y, width, height)`."""
ret
=
self
.
mean
[:
4
].
copy
()
ret
[
2
]
*=
ret
[
3
]
ret
[:
2
]
-=
ret
[
2
:]
/
2
return
ret
def
to_tlbr
(
self
):
"""Get position in bounding box format `(min x, miny, max x, max y)`."""
ret
=
self
.
to_tlwh
()
ret
[
2
:]
=
ret
[:
2
]
+
ret
[
2
:]
return
ret
def
predict
(
self
,
kalman_filter
):
"""
Propagate the state distribution to the current time step using a Kalman
filter prediction step.
"""
self
.
mean
,
self
.
covariance
=
kalman_filter
.
predict
(
self
.
mean
,
self
.
covariance
)
self
.
age
+=
1
self
.
time_since_update
+=
1
def
update
(
self
,
kalman_filter
,
detection
):
"""
Perform Kalman filter measurement update step and update the associated
detection feature cache.
"""
self
.
mean
,
self
.
covariance
=
kalman_filter
.
update
(
self
.
mean
,
self
.
covariance
,
detection
.
to_xyah
())
self
.
features
.
append
(
detection
.
feature
)
self
.
hits
+=
1
self
.
time_since_update
=
0
if
self
.
state
==
TrackState
.
Tentative
and
self
.
hits
>=
self
.
_n_init
:
self
.
state
=
TrackState
.
Confirmed
def
mark_missed
(
self
):
"""Mark this track as missed (no association at the current time step).
"""
if
self
.
state
==
TrackState
.
Tentative
:
self
.
state
=
TrackState
.
Deleted
elif
self
.
time_since_update
>
self
.
_max_age
:
self
.
state
=
TrackState
.
Deleted
def
is_tentative
(
self
):
"""Returns True if this track is tentative (unconfirmed)."""
return
self
.
state
==
TrackState
.
Tentative
def
is_confirmed
(
self
):
"""Returns True if this track is confirmed."""
return
self
.
state
==
TrackState
.
Confirmed
def
is_deleted
(
self
):
"""Returns True if this track is dead and should be deleted."""
return
self
.
state
==
TrackState
.
Deleted
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/jde_tracker.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import
paddle
from
..matching
import
jde_matching
as
matching
from
.base_jde_tracker
import
TrackState
,
BaseTrack
,
STrack
from
.base_jde_tracker
import
joint_stracks
,
sub_stracks
,
remove_duplicate_stracks
from
ppdet.core.workspace
import
register
,
serializable
from
ppdet.utils.logger
import
setup_logger
logger
=
setup_logger
(
__name__
)
__all__
=
[
'FrozenJDETracker'
]
@
register
@
serializable
class
FrozenJDETracker
(
object
):
__inject__
=
[
'motion'
]
"""
JDE tracker
Args:
det_thresh (float): threshold of detection score
track_buffer (int): buffer for tracker
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
bad results, set 1.6 default for pedestrian tracking. If set -1
means no need to filter bboxes.
tracked_thresh (float): linear assignment threshold of tracked
stracks and detections
r_tracked_thresh (float): linear assignment threshold of
tracked stracks and unmatched detections
unconfirmed_thresh (float): linear assignment threshold of
unconfirmed stracks and unmatched detections
motion (object): KalmanFilter instance
conf_thres (float): confidence threshold for tracking
metric_type (str): either "euclidean" or "cosine", the distance metric
used for measurement to track association.
"""
def
__init__
(
self
,
det_thresh
=
0.3
,
track_buffer
=
30
,
min_box_area
=
200
,
vertical_ratio
=
1.6
,
tracked_thresh
=
0.7
,
r_tracked_thresh
=
0.5
,
unconfirmed_thresh
=
0.7
,
motion
=
'KalmanFilter'
,
conf_thres
=
0
,
metric_type
=
'euclidean'
):
self
.
det_thresh
=
det_thresh
self
.
track_buffer
=
track_buffer
self
.
min_box_area
=
min_box_area
self
.
vertical_ratio
=
vertical_ratio
self
.
tracked_thresh
=
tracked_thresh
self
.
r_tracked_thresh
=
r_tracked_thresh
self
.
unconfirmed_thresh
=
unconfirmed_thresh
self
.
motion
=
motion
self
.
conf_thres
=
conf_thres
self
.
metric_type
=
metric_type
self
.
frame_id
=
0
self
.
tracked_stracks
=
[]
self
.
lost_stracks
=
[]
self
.
removed_stracks
=
[]
self
.
max_time_lost
=
0
# max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
def
update
(
self
,
pred_dets
,
pred_embs
):
"""
Processes the image frame and finds bounding box(detections).
Associates the detection with corresponding tracklets and also handles
lost, removed, refound and active tracklets.
Args:
pred_dets (Tensor): Detection results of the image, shape is [N, 5].
pred_embs (Tensor): Embedding results of the image, shape is [N, 512].
Return:
output_stracks (list): The list contains information regarding the
online_tracklets for the recieved image tensor.
"""
self
.
frame_id
+=
1
activated_starcks
=
[]
# for storing active tracks, for the current frame
refind_stracks
=
[]
# Lost Tracks whose detections are obtained in the current frame
lost_stracks
=
[]
# The tracks which are not obtained in the current frame but are not
# removed. (Lost for some time lesser than the threshold for removing)
removed_stracks
=
[]
remain_inds
=
paddle
.
nonzero
(
pred_dets
[:,
4
]
>
self
.
conf_thres
)
if
remain_inds
.
shape
[
0
]
==
0
:
pred_dets
=
paddle
.
zeros
([
0
,
1
])
pred_embs
=
paddle
.
zeros
([
0
,
1
])
else
:
pred_dets
=
paddle
.
gather
(
pred_dets
,
remain_inds
)
pred_embs
=
paddle
.
gather
(
pred_embs
,
remain_inds
)
# Filter out the image with box_num = 0. pred_dets = [[0.0, 0.0, 0.0 ,0.0]]
empty_pred
=
True
if
len
(
pred_dets
)
==
1
and
paddle
.
sum
(
pred_dets
)
==
0.0
else
False
""" Step 1: Network forward, get detections & embeddings"""
if
len
(
pred_dets
)
>
0
and
not
empty_pred
:
pred_dets
=
pred_dets
.
numpy
()
pred_embs
=
pred_embs
.
numpy
()
detections
=
[
STrack
(
STrack
.
tlbr_to_tlwh
(
tlbrs
[:
4
]),
tlbrs
[
4
],
f
,
30
)
for
(
tlbrs
,
f
)
in
zip
(
pred_dets
,
pred_embs
)
]
else
:
detections
=
[]
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed
=
[]
tracked_stracks
=
[]
# type: list[STrack]
for
track
in
self
.
tracked_stracks
:
if
not
track
.
is_activated
:
# previous tracks which are not active in the current frame are added in unconfirmed list
unconfirmed
.
append
(
track
)
else
:
# Active tracks are added to the local list 'tracked_stracks'
tracked_stracks
.
append
(
track
)
""" Step 2: First association, with embedding"""
# Combining currently tracked_stracks and lost_stracks
strack_pool
=
joint_stracks
(
tracked_stracks
,
self
.
lost_stracks
)
# Predict the current location with KF
STrack
.
multi_predict
(
strack_pool
,
self
.
motion
)
dists
=
matching
.
embedding_distance
(
strack_pool
,
detections
,
metric
=
self
.
metric_type
)
dists
=
matching
.
fuse_motion
(
self
.
motion
,
dists
,
strack_pool
,
detections
)
# The dists is the list of distances of the detection with the tracks in strack_pool
matches
,
u_track
,
u_detection
=
matching
.
linear_assignment
(
dists
,
thresh
=
self
.
tracked_thresh
)
# The matches is the array for corresponding matches of the detection with the corresponding strack_pool
for
itracked
,
idet
in
matches
:
# itracked is the id of the track and idet is the detection
track
=
strack_pool
[
itracked
]
det
=
detections
[
idet
]
if
track
.
state
==
TrackState
.
Tracked
:
# If the track is active, add the detection to the track
track
.
update
(
detections
[
idet
],
self
.
frame_id
)
activated_starcks
.
append
(
track
)
else
:
# We have obtained a detection from a track which is not active,
# hence put the track in refind_stracks list
track
.
re_activate
(
det
,
self
.
frame_id
,
new_id
=
False
)
refind_stracks
.
append
(
track
)
# None of the steps below happen if there are no undetected tracks.
""" Step 3: Second association, with IOU"""
detections
=
[
detections
[
i
]
for
i
in
u_detection
]
# detections is now a list of the unmatched detections
r_tracked_stracks
=
[]
# This is container for stracks which were tracked till the previous
# frame but no detection was found for it in the current frame.
for
i
in
u_track
:
if
strack_pool
[
i
].
state
==
TrackState
.
Tracked
:
r_tracked_stracks
.
append
(
strack_pool
[
i
])
dists
=
matching
.
iou_distance
(
r_tracked_stracks
,
detections
)
matches
,
u_track
,
u_detection
=
matching
.
linear_assignment
(
dists
,
thresh
=
self
.
r_tracked_thresh
)
# matches is the list of detections which matched with corresponding
# tracks by IOU distance method.
for
itracked
,
idet
in
matches
:
track
=
r_tracked_stracks
[
itracked
]
det
=
detections
[
idet
]
if
track
.
state
==
TrackState
.
Tracked
:
track
.
update
(
det
,
self
.
frame_id
)
activated_starcks
.
append
(
track
)
else
:
track
.
re_activate
(
det
,
self
.
frame_id
,
new_id
=
False
)
refind_stracks
.
append
(
track
)
# Same process done for some unmatched detections, but now considering IOU_distance as measure
for
it
in
u_track
:
track
=
r_tracked_stracks
[
it
]
if
not
track
.
state
==
TrackState
.
Lost
:
track
.
mark_lost
()
lost_stracks
.
append
(
track
)
# If no detections are obtained for tracks (u_track), the tracks are added to lost_tracks list and are marked lost
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections
=
[
detections
[
i
]
for
i
in
u_detection
]
dists
=
matching
.
iou_distance
(
unconfirmed
,
detections
)
matches
,
u_unconfirmed
,
u_detection
=
matching
.
linear_assignment
(
dists
,
thresh
=
self
.
unconfirmed_thresh
)
for
itracked
,
idet
in
matches
:
unconfirmed
[
itracked
].
update
(
detections
[
idet
],
self
.
frame_id
)
activated_starcks
.
append
(
unconfirmed
[
itracked
])
# The tracks which are yet not matched
for
it
in
u_unconfirmed
:
track
=
unconfirmed
[
it
]
track
.
mark_removed
()
removed_stracks
.
append
(
track
)
# after all these confirmation steps, if a new detection is found, it is initialized for a new track
""" Step 4: Init new stracks"""
for
inew
in
u_detection
:
track
=
detections
[
inew
]
if
track
.
score
<
self
.
det_thresh
:
continue
track
.
activate
(
self
.
motion
,
self
.
frame_id
)
activated_starcks
.
append
(
track
)
""" Step 5: Update state"""
# If the tracks are lost for more frames than the threshold number, the tracks are removed.
for
track
in
self
.
lost_stracks
:
if
self
.
frame_id
-
track
.
end_frame
>
self
.
max_time_lost
:
track
.
mark_removed
()
removed_stracks
.
append
(
track
)
# Update the self.tracked_stracks and self.lost_stracks using the updates in this step.
self
.
tracked_stracks
=
[
t
for
t
in
self
.
tracked_stracks
if
t
.
state
==
TrackState
.
Tracked
]
self
.
tracked_stracks
=
joint_stracks
(
self
.
tracked_stracks
,
activated_starcks
)
self
.
tracked_stracks
=
joint_stracks
(
self
.
tracked_stracks
,
refind_stracks
)
self
.
lost_stracks
=
sub_stracks
(
self
.
lost_stracks
,
self
.
tracked_stracks
)
self
.
lost_stracks
.
extend
(
lost_stracks
)
self
.
lost_stracks
=
sub_stracks
(
self
.
lost_stracks
,
self
.
removed_stracks
)
self
.
removed_stracks
.
extend
(
removed_stracks
)
self
.
tracked_stracks
,
self
.
lost_stracks
=
remove_duplicate_stracks
(
self
.
tracked_stracks
,
self
.
lost_stracks
)
# get scores of lost tracks
output_stracks
=
[
track
for
track
in
self
.
tracked_stracks
if
track
.
is_activated
]
logger
.
debug
(
'===========Frame {}=========='
.
format
(
self
.
frame_id
))
logger
.
debug
(
'Activated: {}'
.
format
([
track
.
track_id
for
track
in
activated_starcks
]))
logger
.
debug
(
'Refind: {}'
.
format
([
track
.
track_id
for
track
in
refind_stracks
]))
logger
.
debug
(
'Lost: {}'
.
format
([
track
.
track_id
for
track
in
lost_stracks
]))
logger
.
debug
(
'Removed: {}'
.
format
([
track
.
track_id
for
track
in
removed_stracks
]))
return
output_stracks
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/utils.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
time
import
paddle
import
numpy
as
np
__all__
=
[
'Timer'
,
'Detection'
,
'load_det_results'
,
'preprocess_reid'
,
'get_crops'
,
'clip_box'
,
'scale_coords'
,
]
class
Timer
(
object
):
"""
This class used to compute and print the current FPS while evaling.
"""
def
__init__
(
self
):
self
.
total_time
=
0.
self
.
calls
=
0
self
.
start_time
=
0.
self
.
diff
=
0.
self
.
average_time
=
0.
self
.
duration
=
0.
def
tic
(
self
):
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
self
.
start_time
=
time
.
time
()
def
toc
(
self
,
average
=
True
):
self
.
diff
=
time
.
time
()
-
self
.
start_time
self
.
total_time
+=
self
.
diff
self
.
calls
+=
1
self
.
average_time
=
self
.
total_time
/
self
.
calls
if
average
:
self
.
duration
=
self
.
average_time
else
:
self
.
duration
=
self
.
diff
return
self
.
duration
def
clear
(
self
):
self
.
total_time
=
0.
self
.
calls
=
0
self
.
start_time
=
0.
self
.
diff
=
0.
self
.
average_time
=
0.
self
.
duration
=
0.
class
Detection
(
object
):
"""
This class represents a bounding box detection in a single image.
Args:
tlwh (ndarray): Bounding box in format `(top left x, top left y,
width, height)`.
confidence (ndarray): Detector confidence score.
feature (Tensor): A feature vector that describes the object
contained in this image.
"""
def
__init__
(
self
,
tlwh
,
confidence
,
feature
):
self
.
tlwh
=
np
.
asarray
(
tlwh
,
dtype
=
np
.
float32
)
self
.
confidence
=
np
.
asarray
(
confidence
,
dtype
=
np
.
float32
)
self
.
feature
=
feature
def
to_tlbr
(
self
):
"""
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret
=
self
.
tlwh
.
copy
()
ret
[
2
:]
+=
ret
[:
2
]
return
ret
def
to_xyah
(
self
):
"""
Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret
=
self
.
tlwh
.
copy
()
ret
[:
2
]
+=
ret
[
2
:]
/
2
ret
[
2
]
/=
ret
[
3
]
return
ret
def
load_det_results
(
det_file
,
num_frames
):
assert
os
.
path
.
exists
(
det_file
)
and
os
.
path
.
isfile
(
det_file
),
\
'Error: det_file: {} not exist or not a file.'
.
format
(
det_file
)
labels
=
np
.
loadtxt
(
det_file
,
dtype
=
'float32'
,
delimiter
=
','
)
results_list
=
[]
for
frame_i
in
range
(
0
,
num_frames
):
results
=
{
'bbox'
:
[],
'score'
:
[]}
lables_with_frame
=
labels
[
labels
[:,
0
]
==
frame_i
+
1
]
for
l
in
lables_with_frame
:
results
[
'bbox'
].
append
(
l
[
1
:
5
])
results
[
'score'
].
append
(
l
[
5
])
results_list
.
append
(
results
)
return
results_list
def
scale_coords
(
coords
,
input_shape
,
im_shape
,
scale_factor
):
im_shape
=
im_shape
.
numpy
()[
0
]
ratio
=
scale_factor
[
0
][
0
]
pad_w
=
(
input_shape
[
1
]
-
int
(
im_shape
[
1
]))
/
2
pad_h
=
(
input_shape
[
0
]
-
int
(
im_shape
[
0
]))
/
2
coords
=
paddle
.
cast
(
coords
,
'float32'
)
coords
[:,
0
::
2
]
-=
pad_w
coords
[:,
1
::
2
]
-=
pad_h
coords
[:,
0
:
4
]
/=
ratio
coords
[:,
:
4
]
=
paddle
.
clip
(
coords
[:,
:
4
],
min
=
0
,
max
=
coords
[:,
:
4
].
max
())
return
coords
.
round
()
def
clip_box
(
xyxy
,
input_shape
,
im_shape
,
scale_factor
):
im_shape
=
im_shape
.
numpy
()[
0
]
ratio
=
scale_factor
.
numpy
()[
0
][
0
]
img0_shape
=
[
int
(
im_shape
[
0
]
/
ratio
),
int
(
im_shape
[
1
]
/
ratio
)]
xyxy
[:,
0
::
2
]
=
paddle
.
clip
(
xyxy
[:,
0
::
2
],
min
=
0
,
max
=
img0_shape
[
1
])
xyxy
[:,
1
::
2
]
=
paddle
.
clip
(
xyxy
[:,
1
::
2
],
min
=
0
,
max
=
img0_shape
[
0
])
return
xyxy
def
get_crops
(
xyxy
,
ori_img
,
pred_scores
,
w
,
h
):
crops
=
[]
keep_scores
=
[]
xyxy
=
xyxy
.
numpy
().
astype
(
np
.
int64
)
ori_img
=
ori_img
.
numpy
()
ori_img
=
np
.
squeeze
(
ori_img
,
axis
=
0
).
transpose
(
1
,
0
,
2
)
pred_scores
=
pred_scores
.
numpy
()
for
i
,
bbox
in
enumerate
(
xyxy
):
if
bbox
[
2
]
<=
bbox
[
0
]
or
bbox
[
3
]
<=
bbox
[
1
]:
continue
crop
=
ori_img
[
bbox
[
0
]:
bbox
[
2
],
bbox
[
1
]:
bbox
[
3
],
:]
crops
.
append
(
crop
)
keep_scores
.
append
(
pred_scores
[
i
])
if
len
(
crops
)
==
0
:
return
[],
[]
crops
=
preprocess_reid
(
crops
,
w
,
h
)
return
crops
,
keep_scores
def
preprocess_reid
(
imgs
,
w
=
64
,
h
=
192
,
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]):
im_batch
=
[]
for
img
in
imgs
:
img
=
cv2
.
resize
(
img
,
(
w
,
h
))
img
=
img
[:,
:,
::
-
1
].
astype
(
'float32'
).
transpose
((
2
,
0
,
1
))
/
255
img_mean
=
np
.
array
(
mean
).
reshape
((
3
,
1
,
1
))
img_std
=
np
.
array
(
std
).
reshape
((
3
,
1
,
1
))
img
-=
img_mean
img
/=
img_std
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
im_batch
.
append
(
img
)
im_batch
=
np
.
concatenate
(
im_batch
,
0
)
return
im_batch
modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/visualization.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
cv2
import
numpy
as
np
def
tlwhs_to_tlbrs
(
tlwhs
):
tlbrs
=
np
.
copy
(
tlwhs
)
if
len
(
tlbrs
)
==
0
:
return
tlbrs
tlbrs
[:,
2
]
+=
tlwhs
[:,
0
]
tlbrs
[:,
3
]
+=
tlwhs
[:,
1
]
return
tlbrs
def
get_color
(
idx
):
idx
=
idx
*
3
color
=
((
37
*
idx
)
%
255
,
(
17
*
idx
)
%
255
,
(
29
*
idx
)
%
255
)
return
color
def
resize_image
(
image
,
max_size
=
800
):
if
max
(
image
.
shape
[:
2
])
>
max_size
:
scale
=
float
(
max_size
)
/
max
(
image
.
shape
[:
2
])
image
=
cv2
.
resize
(
image
,
None
,
fx
=
scale
,
fy
=
scale
)
return
image
def
plot_tracking
(
image
,
tlwhs
,
obj_ids
,
scores
=
None
,
frame_id
=
0
,
fps
=
0.
,
ids2
=
None
):
im
=
np
.
ascontiguousarray
(
np
.
copy
(
image
))
im_h
,
im_w
=
im
.
shape
[:
2
]
top_view
=
np
.
zeros
([
im_w
,
im_w
,
3
],
dtype
=
np
.
uint8
)
+
255
text_scale
=
max
(
1
,
image
.
shape
[
1
]
/
1600.
)
text_thickness
=
2
line_thickness
=
max
(
1
,
int
(
image
.
shape
[
1
]
/
500.
))
radius
=
max
(
5
,
int
(
im_w
/
140.
))
cv2
.
putText
(
im
,
'frame: %d fps: %.2f num: %d'
%
(
frame_id
,
fps
,
len
(
tlwhs
)),
(
0
,
int
(
15
*
text_scale
)),
cv2
.
FONT_HERSHEY_PLAIN
,
text_scale
,
(
0
,
0
,
255
),
thickness
=
2
)
for
i
,
tlwh
in
enumerate
(
tlwhs
):
x1
,
y1
,
w
,
h
=
tlwh
intbox
=
tuple
(
map
(
int
,
(
x1
,
y1
,
x1
+
w
,
y1
+
h
)))
obj_id
=
int
(
obj_ids
[
i
])
id_text
=
'{}'
.
format
(
int
(
obj_id
))
if
ids2
is
not
None
:
id_text
=
id_text
+
', {}'
.
format
(
int
(
ids2
[
i
]))
_line_thickness
=
1
if
obj_id
<=
0
else
line_thickness
color
=
get_color
(
abs
(
obj_id
))
cv2
.
rectangle
(
im
,
intbox
[
0
:
2
],
intbox
[
2
:
4
],
color
=
color
,
thickness
=
line_thickness
)
cv2
.
putText
(
im
,
id_text
,
(
intbox
[
0
],
intbox
[
1
]
+
10
),
cv2
.
FONT_HERSHEY_PLAIN
,
text_scale
,
(
0
,
0
,
255
),
thickness
=
text_thickness
)
if
scores
is
not
None
:
text
=
'{:.2f}'
.
format
(
float
(
scores
[
i
]))
cv2
.
putText
(
im
,
text
,
(
intbox
[
0
],
intbox
[
1
]
-
10
),
cv2
.
FONT_HERSHEY_PLAIN
,
text_scale
,
(
0
,
255
,
255
),
thickness
=
text_thickness
)
return
im
def
plot_trajectory
(
image
,
tlwhs
,
track_ids
):
image
=
image
.
copy
()
for
one_tlwhs
,
track_id
in
zip
(
tlwhs
,
track_ids
):
color
=
get_color
(
int
(
track_id
))
for
tlwh
in
one_tlwhs
:
x1
,
y1
,
w
,
h
=
tuple
(
map
(
int
,
tlwh
))
cv2
.
circle
(
image
,
(
int
(
x1
+
0.5
*
w
),
int
(
y1
+
h
)),
2
,
color
,
thickness
=
2
)
return
image
def
plot_detections
(
image
,
tlbrs
,
scores
=
None
,
color
=
(
255
,
0
,
0
),
ids
=
None
):
im
=
np
.
copy
(
image
)
text_scale
=
max
(
1
,
image
.
shape
[
1
]
/
800.
)
thickness
=
2
if
text_scale
>
1.3
else
1
for
i
,
det
in
enumerate
(
tlbrs
):
x1
,
y1
,
x2
,
y2
=
np
.
asarray
(
det
[:
4
],
dtype
=
np
.
int
)
if
len
(
det
)
>=
7
:
label
=
'det'
if
det
[
5
]
>
0
else
'trk'
if
ids
is
not
None
:
text
=
'{}# {:.2f}: {:d}'
.
format
(
label
,
det
[
6
],
ids
[
i
])
cv2
.
putText
(
im
,
text
,
(
x1
,
y1
+
30
),
cv2
.
FONT_HERSHEY_PLAIN
,
text_scale
,
(
0
,
255
,
255
),
thickness
=
thickness
)
else
:
text
=
'{}# {:.2f}'
.
format
(
label
,
det
[
6
])
if
scores
is
not
None
:
text
=
'{:.2f}'
.
format
(
scores
[
i
])
cv2
.
putText
(
im
,
text
,
(
x1
,
y1
+
30
),
cv2
.
FONT_HERSHEY_PLAIN
,
text_scale
,
(
0
,
255
,
255
),
thickness
=
thickness
)
cv2
.
rectangle
(
im
,
(
x1
,
y1
),
(
x2
,
y2
),
color
,
2
)
return
im
modules/video/multiple_object_tracking/fairmot_dla34/tracker.py
浏览文件 @
d53f6412
...
...
@@ -16,18 +16,19 @@ import cv2
import
glob
import
paddle
import
numpy
as
np
import
collections
from
ppdet.core.workspace
import
create
from
ppdet.utils.checkpoint
import
load_weight
,
load_pretrain_weight
from
ppdet.modeling.mot.utils
import
Detection
,
get_crops
,
scale_coords
,
clip_box
from
ppdet.modeling.mot.utils
import
Timer
,
load_det_results
from
ppdet.modeling.mot
import
visualization
as
mot_vis
from
ppdet.metrics
import
Metric
,
MOTMetric
,
KITTIMOTMetric
import
ppdet.utils.stats
as
stats
from
ppdet.engine.callbacks
import
Callback
,
ComposeCallback
from
ppdet.utils.logger
import
setup_logger
from
.dataset
import
MOTVideoStream
,
MOTImageStream
from
.utils
import
Timer
from
.modeling.mot.utils
import
Detection
,
get_crops
,
scale_coords
,
clip_box
from
.modeling.mot
import
visualization
as
mot_vis
logger
=
setup_logger
(
__name__
)
...
...
@@ -71,7 +72,6 @@ class StreamTracker(object):
timer
.
tic
()
pred_dets
,
pred_embs
=
self
.
model
(
data
)
online_targets
=
self
.
model
.
tracker
.
update
(
pred_dets
,
pred_embs
)
online_tlwhs
,
online_ids
=
[],
[]
online_scores
=
[]
for
t
in
online_targets
:
...
...
@@ -109,7 +109,6 @@ class StreamTracker(object):
timer
.
tic
()
pred_dets
,
pred_embs
=
self
.
model
(
data
)
online_targets
=
self
.
model
.
tracker
.
update
(
pred_dets
,
pred_embs
)
online_tlwhs
,
online_ids
=
[],
[]
online_scores
=
[]
for
t
in
online_targets
:
...
...
modules/video/multiple_object_tracking/fairmot_dla34/utils.py
0 → 100644
浏览文件 @
d53f6412
import
time
class
Timer
(
object
):
"""
This class used to compute and print the current FPS while evaling.
"""
def
__init__
(
self
):
self
.
total_time
=
0.
self
.
calls
=
0
self
.
start_time
=
0.
self
.
diff
=
0.
self
.
average_time
=
0.
self
.
duration
=
0.
def
tic
(
self
):
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
self
.
start_time
=
time
.
time
()
def
toc
(
self
,
average
=
True
):
self
.
diff
=
time
.
time
()
-
self
.
start_time
self
.
total_time
+=
self
.
diff
self
.
calls
+=
1
self
.
average_time
=
self
.
total_time
/
self
.
calls
if
average
:
self
.
duration
=
self
.
average_time
else
:
self
.
duration
=
self
.
diff
return
self
.
duration
def
clear
(
self
):
self
.
total_time
=
0.
self
.
calls
=
0
self
.
start_time
=
0.
self
.
diff
=
0.
self
.
average_time
=
0.
self
.
duration
=
0.
modules/video/multiple_object_tracking/jde_darknet53/config/_base_/jde_darknet53.yml
浏览文件 @
d53f6412
...
...
@@ -5,7 +5,7 @@ find_unused_parameters: True
JDE
:
detector
:
YOLOv3
reid
:
JDEEmbeddingHead
tracker
:
JDETracker
tracker
:
Frozen
JDETracker
YOLOv3
:
backbone
:
DarkNet
...
...
modules/video/multiple_object_tracking/jde_darknet53/config/jde_darknet53_30e_1088x608.yml
浏览文件 @
d53f6412
...
...
@@ -9,7 +9,7 @@ _BASE_: [
JDE
:
detector
:
YOLOv3
reid
:
JDEEmbeddingHead
tracker
:
JDETracker
tracker
:
Frozen
JDETracker
YOLOv3
:
backbone
:
DarkNet
...
...
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/__init__.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
matching
from
.
import
tracker
from
.
import
motion
from
.
import
visualization
from
.
import
utils
from
.matching
import
*
from
.tracker
import
*
from
.motion
import
*
from
.visualization
import
*
from
.utils
import
*
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/__init__.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
jde_matching
from
.
import
deepsort_matching
from
.jde_matching
import
*
from
.deepsort_matching
import
*
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/deepsort_matching.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/tree/master/deep_sort
"""
import
numpy
as
np
from
scipy.optimize
import
linear_sum_assignment
from
..motion
import
kalman_filter
INFTY_COST
=
1e+5
__all__
=
[
'iou_1toN'
,
'iou_cost'
,
'_nn_euclidean_distance'
,
'_nn_cosine_distance'
,
'NearestNeighborDistanceMetric'
,
'min_cost_matching'
,
'matching_cascade'
,
'gate_cost_matrix'
,
]
def
iou_1toN
(
bbox
,
candidates
):
"""
Computer intersection over union (IoU) by one box to N candidates.
Args:
bbox (ndarray): A bounding box in format `(top left x, top left y, width, height)`.
candidates (ndarray): A matrix of candidate bounding boxes (one per row) in the
same format as `bbox`.
Returns:
ious (ndarray): The intersection over union in [0, 1] between the `bbox`
and each candidate. A higher score means a larger fraction of the
`bbox` is occluded by the candidate.
"""
bbox_tl
=
bbox
[:
2
]
bbox_br
=
bbox
[:
2
]
+
bbox
[
2
:]
candidates_tl
=
candidates
[:,
:
2
]
candidates_br
=
candidates
[:,
:
2
]
+
candidates
[:,
2
:]
tl
=
np
.
c_
[
np
.
maximum
(
bbox_tl
[
0
],
candidates_tl
[:,
0
])[:,
np
.
newaxis
],
np
.
maximum
(
bbox_tl
[
1
],
candidates_tl
[:,
1
])[:,
np
.
newaxis
]]
br
=
np
.
c_
[
np
.
minimum
(
bbox_br
[
0
],
candidates_br
[:,
0
])[:,
np
.
newaxis
],
np
.
minimum
(
bbox_br
[
1
],
candidates_br
[:,
1
])[:,
np
.
newaxis
]]
wh
=
np
.
maximum
(
0.
,
br
-
tl
)
area_intersection
=
wh
.
prod
(
axis
=
1
)
area_bbox
=
bbox
[
2
:].
prod
()
area_candidates
=
candidates
[:,
2
:].
prod
(
axis
=
1
)
ious
=
area_intersection
/
(
area_bbox
+
area_candidates
-
area_intersection
)
return
ious
def
iou_cost
(
tracks
,
detections
,
track_indices
=
None
,
detection_indices
=
None
):
"""
IoU distance metric.
Args:
tracks (list[Track]): A list of tracks.
detections (list[Detection]): A list of detections.
track_indices (Optional[list[int]]): A list of indices to tracks that
should be matched. Defaults to all `tracks`.
detection_indices (Optional[list[int]]): A list of indices to detections
that should be matched. Defaults to all `detections`.
Returns:
cost_matrix (ndarray): A cost matrix of shape len(track_indices),
len(detection_indices) where entry (i, j) is
`1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
"""
if
track_indices
is
None
:
track_indices
=
np
.
arange
(
len
(
tracks
))
if
detection_indices
is
None
:
detection_indices
=
np
.
arange
(
len
(
detections
))
cost_matrix
=
np
.
zeros
((
len
(
track_indices
),
len
(
detection_indices
)))
for
row
,
track_idx
in
enumerate
(
track_indices
):
if
tracks
[
track_idx
].
time_since_update
>
1
:
cost_matrix
[
row
,
:]
=
1e+5
continue
bbox
=
tracks
[
track_idx
].
to_tlwh
()
candidates
=
np
.
asarray
([
detections
[
i
].
tlwh
for
i
in
detection_indices
])
cost_matrix
[
row
,
:]
=
1.
-
iou_1toN
(
bbox
,
candidates
)
return
cost_matrix
def
_nn_euclidean_distance
(
s
,
q
):
"""
Compute pair-wise squared (Euclidean) distance between points in `s` and `q`.
Args:
s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
Returns:
distances (ndarray): A vector of length M that contains for each entry in `q` the
smallest Euclidean distance to a sample in `s`.
"""
s
,
q
=
np
.
asarray
(
s
),
np
.
asarray
(
q
)
if
len
(
s
)
==
0
or
len
(
q
)
==
0
:
return
np
.
zeros
((
len
(
s
),
len
(
q
)))
s2
,
q2
=
np
.
square
(
s
).
sum
(
axis
=
1
),
np
.
square
(
q
).
sum
(
axis
=
1
)
distances
=
-
2.
*
np
.
dot
(
s
,
q
.
T
)
+
s2
[:,
None
]
+
q2
[
None
,
:]
distances
=
np
.
clip
(
distances
,
0.
,
float
(
np
.
inf
))
return
np
.
maximum
(
0.0
,
distances
.
min
(
axis
=
0
))
def
_nn_cosine_distance
(
s
,
q
):
"""
Compute pair-wise cosine distance between points in `s` and `q`.
Args:
s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
Returns:
distances (ndarray): A vector of length M that contains for each entry in `q` the
smallest Euclidean distance to a sample in `s`.
"""
s
=
np
.
asarray
(
s
)
/
np
.
linalg
.
norm
(
s
,
axis
=
1
,
keepdims
=
True
)
q
=
np
.
asarray
(
q
)
/
np
.
linalg
.
norm
(
q
,
axis
=
1
,
keepdims
=
True
)
distances
=
1.
-
np
.
dot
(
s
,
q
.
T
)
return
distances
.
min
(
axis
=
0
)
class
NearestNeighborDistanceMetric
(
object
):
"""
A nearest neighbor distance metric that, for each target, returns
the closest distance to any sample that has been observed so far.
Args:
metric (str): Either "euclidean" or "cosine".
matching_threshold (float): The matching threshold. Samples with larger
distance are considered an invalid match.
budget (Optional[int]): If not None, fix samples per class to at most
this number. Removes the oldest samples when the budget is reached.
Attributes:
samples (Dict[int -> List[ndarray]]): A dictionary that maps from target
identities to the list of samples that have been observed so far.
"""
def
__init__
(
self
,
metric
,
matching_threshold
,
budget
=
None
):
if
metric
==
"euclidean"
:
self
.
_metric
=
_nn_euclidean_distance
elif
metric
==
"cosine"
:
self
.
_metric
=
_nn_cosine_distance
else
:
raise
ValueError
(
"Invalid metric; must be either 'euclidean' or 'cosine'"
)
self
.
matching_threshold
=
matching_threshold
self
.
budget
=
budget
self
.
samples
=
{}
def
partial_fit
(
self
,
features
,
targets
,
active_targets
):
"""
Update the distance metric with new data.
Args:
features (ndarray): An NxM matrix of N features of dimensionality M.
targets (ndarray): An integer array of associated target identities.
active_targets (List[int]): A list of targets that are currently
present in the scene.
"""
for
feature
,
target
in
zip
(
features
,
targets
):
self
.
samples
.
setdefault
(
target
,
[]).
append
(
feature
)
if
self
.
budget
is
not
None
:
self
.
samples
[
target
]
=
self
.
samples
[
target
][
-
self
.
budget
:]
self
.
samples
=
{
k
:
self
.
samples
[
k
]
for
k
in
active_targets
}
def
distance
(
self
,
features
,
targets
):
"""
Compute distance between features and targets.
Args:
features (ndarray): An NxM matrix of N features of dimensionality M.
targets (list[int]): A list of targets to match the given `features` against.
Returns:
cost_matrix (ndarray): a cost matrix of shape len(targets), len(features),
where element (i, j) contains the closest squared distance between
`targets[i]` and `features[j]`.
"""
cost_matrix
=
np
.
zeros
((
len
(
targets
),
len
(
features
)))
for
i
,
target
in
enumerate
(
targets
):
cost_matrix
[
i
,
:]
=
self
.
_metric
(
self
.
samples
[
target
],
features
)
return
cost_matrix
def
min_cost_matching
(
distance_metric
,
max_distance
,
tracks
,
detections
,
track_indices
=
None
,
detection_indices
=
None
):
"""
Solve linear assignment problem.
Args:
distance_metric :
Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
The distance metric is given a list of tracks and detections as
well as a list of N track indices and M detection indices. The
metric should return the NxM dimensional cost matrix, where element
(i, j) is the association cost between the i-th track in the given
track indices and the j-th detection in the given detection_indices.
max_distance (float): Gating threshold. Associations with cost larger
than this value are disregarded.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (list[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
Returns:
A tuple (List[(int, int)], List[int], List[int]) with the following
three entries:
* A list of matched track and detection indices.
* A list of unmatched track indices.
* A list of unmatched detection indices.
"""
if
track_indices
is
None
:
track_indices
=
np
.
arange
(
len
(
tracks
))
if
detection_indices
is
None
:
detection_indices
=
np
.
arange
(
len
(
detections
))
if
len
(
detection_indices
)
==
0
or
len
(
track_indices
)
==
0
:
return
[],
track_indices
,
detection_indices
# Nothing to match.
cost_matrix
=
distance_metric
(
tracks
,
detections
,
track_indices
,
detection_indices
)
cost_matrix
[
cost_matrix
>
max_distance
]
=
max_distance
+
1e-5
indices
=
linear_sum_assignment
(
cost_matrix
)
matches
,
unmatched_tracks
,
unmatched_detections
=
[],
[],
[]
for
col
,
detection_idx
in
enumerate
(
detection_indices
):
if
col
not
in
indices
[
1
]:
unmatched_detections
.
append
(
detection_idx
)
for
row
,
track_idx
in
enumerate
(
track_indices
):
if
row
not
in
indices
[
0
]:
unmatched_tracks
.
append
(
track_idx
)
for
row
,
col
in
zip
(
indices
[
0
],
indices
[
1
]):
track_idx
=
track_indices
[
row
]
detection_idx
=
detection_indices
[
col
]
if
cost_matrix
[
row
,
col
]
>
max_distance
:
unmatched_tracks
.
append
(
track_idx
)
unmatched_detections
.
append
(
detection_idx
)
else
:
matches
.
append
((
track_idx
,
detection_idx
))
return
matches
,
unmatched_tracks
,
unmatched_detections
def
matching_cascade
(
distance_metric
,
max_distance
,
cascade_depth
,
tracks
,
detections
,
track_indices
=
None
,
detection_indices
=
None
):
"""
Run matching cascade.
Args:
distance_metric :
Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
The distance metric is given a list of tracks and detections as
well as a list of N track indices and M detection indices. The
metric should return the NxM dimensional cost matrix, where element
(i, j) is the association cost between the i-th track in the given
track indices and the j-th detection in the given detection_indices.
max_distance (float): Gating threshold. Associations with cost larger
than this value are disregarded.
cascade_depth (int): The cascade depth, should be se to the maximum
track age.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (list[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
Returns:
A tuple (List[(int, int)], List[int], List[int]) with the following
three entries:
* A list of matched track and detection indices.
* A list of unmatched track indices.
* A list of unmatched detection indices.
"""
if
track_indices
is
None
:
track_indices
=
list
(
range
(
len
(
tracks
)))
if
detection_indices
is
None
:
detection_indices
=
list
(
range
(
len
(
detections
)))
unmatched_detections
=
detection_indices
matches
=
[]
for
level
in
range
(
cascade_depth
):
if
len
(
unmatched_detections
)
==
0
:
# No detections left
break
track_indices_l
=
[
k
for
k
in
track_indices
if
tracks
[
k
].
time_since_update
==
1
+
level
]
if
len
(
track_indices_l
)
==
0
:
# Nothing to match at this level
continue
matches_l
,
_
,
unmatched_detections
=
\
min_cost_matching
(
distance_metric
,
max_distance
,
tracks
,
detections
,
track_indices_l
,
unmatched_detections
)
matches
+=
matches_l
unmatched_tracks
=
list
(
set
(
track_indices
)
-
set
(
k
for
k
,
_
in
matches
))
return
matches
,
unmatched_tracks
,
unmatched_detections
def
gate_cost_matrix
(
kf
,
cost_matrix
,
tracks
,
detections
,
track_indices
,
detection_indices
,
gated_cost
=
INFTY_COST
,
only_position
=
False
):
"""
Invalidate infeasible entries in cost matrix based on the state
distributions obtained by Kalman filtering.
Args:
kf (object): The Kalman filter.
cost_matrix (ndarray): The NxM dimensional cost matrix, where N is the
number of track indices and M is the number of detection indices,
such that entry (i, j) is the association cost between
`tracks[track_indices[i]]` and `detections[detection_indices[j]]`.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (List[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
gated_cost (Optional[float]): Entries in the cost matrix corresponding
to infeasible associations are set this value. Defaults to a very
large value.
only_position (Optional[bool]): If True, only the x, y position of the
state distribution is considered during gating. Default False.
"""
gating_dim
=
2
if
only_position
else
4
gating_threshold
=
kalman_filter
.
chi2inv95
[
gating_dim
]
measurements
=
np
.
asarray
([
detections
[
i
].
to_xyah
()
for
i
in
detection_indices
])
for
row
,
track_idx
in
enumerate
(
track_indices
):
track
=
tracks
[
track_idx
]
gating_distance
=
kf
.
gating_distance
(
track
.
mean
,
track
.
covariance
,
measurements
,
only_position
)
cost_matrix
[
row
,
gating_distance
>
gating_threshold
]
=
gated_cost
return
cost_matrix
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/jde_matching.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/matching.py
"""
import
lap
import
scipy
import
numpy
as
np
from
scipy.spatial.distance
import
cdist
from
..motion
import
kalman_filter
from
ppdet.utils.logger
import
setup_logger
logger
=
setup_logger
(
__name__
)
__all__
=
[
'merge_matches'
,
'linear_assignment'
,
'cython_bbox_ious'
,
'iou_distance'
,
'embedding_distance'
,
'fuse_motion'
,
]
def
merge_matches
(
m1
,
m2
,
shape
):
O
,
P
,
Q
=
shape
m1
=
np
.
asarray
(
m1
)
m2
=
np
.
asarray
(
m2
)
M1
=
scipy
.
sparse
.
coo_matrix
((
np
.
ones
(
len
(
m1
)),
(
m1
[:,
0
],
m1
[:,
1
])),
shape
=
(
O
,
P
))
M2
=
scipy
.
sparse
.
coo_matrix
((
np
.
ones
(
len
(
m2
)),
(
m2
[:,
0
],
m2
[:,
1
])),
shape
=
(
P
,
Q
))
mask
=
M1
*
M2
match
=
mask
.
nonzero
()
match
=
list
(
zip
(
match
[
0
],
match
[
1
]))
unmatched_O
=
tuple
(
set
(
range
(
O
))
-
set
([
i
for
i
,
j
in
match
]))
unmatched_Q
=
tuple
(
set
(
range
(
Q
))
-
set
([
j
for
i
,
j
in
match
]))
return
match
,
unmatched_O
,
unmatched_Q
def
linear_assignment
(
cost_matrix
,
thresh
):
if
cost_matrix
.
size
==
0
:
return
np
.
empty
((
0
,
2
),
dtype
=
int
),
tuple
(
range
(
cost_matrix
.
shape
[
0
])),
tuple
(
range
(
cost_matrix
.
shape
[
1
]))
matches
,
unmatched_a
,
unmatched_b
=
[],
[],
[]
cost
,
x
,
y
=
lap
.
lapjv
(
cost_matrix
,
extend_cost
=
True
,
cost_limit
=
thresh
)
for
ix
,
mx
in
enumerate
(
x
):
if
mx
>=
0
:
matches
.
append
([
ix
,
mx
])
unmatched_a
=
np
.
where
(
x
<
0
)[
0
]
unmatched_b
=
np
.
where
(
y
<
0
)[
0
]
matches
=
np
.
asarray
(
matches
)
return
matches
,
unmatched_a
,
unmatched_b
def
cython_bbox_ious
(
atlbrs
,
btlbrs
):
ious
=
np
.
zeros
((
len
(
atlbrs
),
len
(
btlbrs
)),
dtype
=
np
.
float
)
if
ious
.
size
==
0
:
return
ious
try
:
import
cython_bbox
except
Exception
as
e
:
logger
.
error
(
'cython_bbox not found, please install cython_bbox.'
'for example: `pip install cython_bbox`.'
)
raise
e
ious
=
cython_bbox
.
bbox_overlaps
(
np
.
ascontiguousarray
(
atlbrs
,
dtype
=
np
.
float
),
np
.
ascontiguousarray
(
btlbrs
,
dtype
=
np
.
float
))
return
ious
def
iou_distance
(
atracks
,
btracks
):
"""
Compute cost based on IoU between two list[STrack].
"""
if
(
len
(
atracks
)
>
0
and
isinstance
(
atracks
[
0
],
np
.
ndarray
))
or
(
len
(
btracks
)
>
0
and
isinstance
(
btracks
[
0
],
np
.
ndarray
)):
atlbrs
=
atracks
btlbrs
=
btracks
else
:
atlbrs
=
[
track
.
tlbr
for
track
in
atracks
]
btlbrs
=
[
track
.
tlbr
for
track
in
btracks
]
_ious
=
cython_bbox_ious
(
atlbrs
,
btlbrs
)
cost_matrix
=
1
-
_ious
return
cost_matrix
def
embedding_distance
(
tracks
,
detections
,
metric
=
'euclidean'
):
"""
Compute cost based on features between two list[STrack].
"""
cost_matrix
=
np
.
zeros
((
len
(
tracks
),
len
(
detections
)),
dtype
=
np
.
float
)
if
cost_matrix
.
size
==
0
:
return
cost_matrix
det_features
=
np
.
asarray
([
track
.
curr_feat
for
track
in
detections
],
dtype
=
np
.
float
)
track_features
=
np
.
asarray
([
track
.
smooth_feat
for
track
in
tracks
],
dtype
=
np
.
float
)
cost_matrix
=
np
.
maximum
(
0.0
,
cdist
(
track_features
,
det_features
,
metric
))
# Nomalized features
return
cost_matrix
def
fuse_motion
(
kf
,
cost_matrix
,
tracks
,
detections
,
only_position
=
False
,
lambda_
=
0.98
):
if
cost_matrix
.
size
==
0
:
return
cost_matrix
gating_dim
=
2
if
only_position
else
4
gating_threshold
=
kalman_filter
.
chi2inv95
[
gating_dim
]
measurements
=
np
.
asarray
([
det
.
to_xyah
()
for
det
in
detections
])
for
row
,
track
in
enumerate
(
tracks
):
gating_distance
=
kf
.
gating_distance
(
track
.
mean
,
track
.
covariance
,
measurements
,
only_position
,
metric
=
'maha'
)
cost_matrix
[
row
,
gating_distance
>
gating_threshold
]
=
np
.
inf
cost_matrix
[
row
]
=
lambda_
*
cost_matrix
[
row
]
+
(
1
-
lambda_
)
*
gating_distance
return
cost_matrix
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/motion/__init__.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
kalman_filter
from
.kalman_filter
import
*
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/motion/kalman_filter.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/kalman_filter.py
"""
import
numpy
as
np
import
scipy.linalg
__all__
=
[
'KalmanFilter'
]
"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
function and used as Mahalanobis gating threshold.
"""
chi2inv95
=
{
1
:
3.8415
,
2
:
5.9915
,
3
:
7.8147
,
4
:
9.4877
,
5
:
11.070
,
6
:
12.592
,
7
:
14.067
,
8
:
15.507
,
9
:
16.919
}
class
KalmanFilter
(
object
):
"""
A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space
x, y, a, h, vx, vy, va, vh
contains the bounding box center position (x, y), aspect ratio a, height h,
and their respective velocities.
Object motion follows a constant velocity model. The bounding box location
(x, y, a, h) is taken as direct observation of the state space (linear
observation model).
"""
def
__init__
(
self
):
ndim
,
dt
=
4
,
1.
# Create Kalman filter model matrices.
self
.
_motion_mat
=
np
.
eye
(
2
*
ndim
,
2
*
ndim
)
for
i
in
range
(
ndim
):
self
.
_motion_mat
[
i
,
ndim
+
i
]
=
dt
self
.
_update_mat
=
np
.
eye
(
ndim
,
2
*
ndim
)
# Motion and observation uncertainty are chosen relative to the current
# state estimate. These weights control the amount of uncertainty in
# the model. This is a bit hacky.
self
.
_std_weight_position
=
1.
/
20
self
.
_std_weight_velocity
=
1.
/
160
def
initiate
(
self
,
measurement
):
"""
Create track from unassociated measurement.
Args:
measurement (ndarray): Bounding box coordinates (x, y, a, h) with
center position (x, y), aspect ratio a, and height h.
Returns:
The mean vector (8 dimensional) and covariance matrix (8x8
dimensional) of the new track. Unobserved velocities are
initialized to 0 mean.
"""
mean_pos
=
measurement
mean_vel
=
np
.
zeros_like
(
mean_pos
)
mean
=
np
.
r_
[
mean_pos
,
mean_vel
]
std
=
[
2
*
self
.
_std_weight_position
*
measurement
[
3
],
2
*
self
.
_std_weight_position
*
measurement
[
3
],
1e-2
,
2
*
self
.
_std_weight_position
*
measurement
[
3
],
10
*
self
.
_std_weight_velocity
*
measurement
[
3
],
10
*
self
.
_std_weight_velocity
*
measurement
[
3
],
1e-5
,
10
*
self
.
_std_weight_velocity
*
measurement
[
3
]
]
covariance
=
np
.
diag
(
np
.
square
(
std
))
return
mean
,
covariance
def
predict
(
self
,
mean
,
covariance
):
"""
Run Kalman filter prediction step.
Args:
mean (ndarray): The 8 dimensional mean vector of the object state
at the previous time step.
covariance (ndarray): The 8x8 dimensional covariance matrix of the
object state at the previous time step.
Returns:
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos
=
[
self
.
_std_weight_position
*
mean
[
3
],
self
.
_std_weight_position
*
mean
[
3
],
1e-2
,
self
.
_std_weight_position
*
mean
[
3
]
]
std_vel
=
[
self
.
_std_weight_velocity
*
mean
[
3
],
self
.
_std_weight_velocity
*
mean
[
3
],
1e-5
,
self
.
_std_weight_velocity
*
mean
[
3
]
]
motion_cov
=
np
.
diag
(
np
.
square
(
np
.
r_
[
std_pos
,
std_vel
]))
#mean = np.dot(self._motion_mat, mean)
mean
=
np
.
dot
(
mean
,
self
.
_motion_mat
.
T
)
covariance
=
np
.
linalg
.
multi_dot
((
self
.
_motion_mat
,
covariance
,
self
.
_motion_mat
.
T
))
+
motion_cov
return
mean
,
covariance
def
project
(
self
,
mean
,
covariance
):
"""
Project state distribution to measurement space.
Args
mean (ndarray): The state's mean vector (8 dimensional array).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
Returns:
The projected mean and covariance matrix of the given state estimate.
"""
std
=
[
self
.
_std_weight_position
*
mean
[
3
],
self
.
_std_weight_position
*
mean
[
3
],
1e-1
,
self
.
_std_weight_position
*
mean
[
3
]
]
innovation_cov
=
np
.
diag
(
np
.
square
(
std
))
mean
=
np
.
dot
(
self
.
_update_mat
,
mean
)
covariance
=
np
.
linalg
.
multi_dot
((
self
.
_update_mat
,
covariance
,
self
.
_update_mat
.
T
))
return
mean
,
covariance
+
innovation_cov
def
multi_predict
(
self
,
mean
,
covariance
):
"""
Run Kalman filter prediction step (Vectorized version).
Args:
mean (ndarray): The Nx8 dimensional mean matrix of the object states
at the previous time step.
covariance (ndarray): The Nx8x8 dimensional covariance matrics of the
object states at the previous time step.
Returns:
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos
=
[
self
.
_std_weight_position
*
mean
[:,
3
],
self
.
_std_weight_position
*
mean
[:,
3
],
1e-2
*
np
.
ones_like
(
mean
[:,
3
]),
self
.
_std_weight_position
*
mean
[:,
3
]
]
std_vel
=
[
self
.
_std_weight_velocity
*
mean
[:,
3
],
self
.
_std_weight_velocity
*
mean
[:,
3
],
1e-5
*
np
.
ones_like
(
mean
[:,
3
]),
self
.
_std_weight_velocity
*
mean
[:,
3
]
]
sqr
=
np
.
square
(
np
.
r_
[
std_pos
,
std_vel
]).
T
motion_cov
=
[]
for
i
in
range
(
len
(
mean
)):
motion_cov
.
append
(
np
.
diag
(
sqr
[
i
]))
motion_cov
=
np
.
asarray
(
motion_cov
)
mean
=
np
.
dot
(
mean
,
self
.
_motion_mat
.
T
)
left
=
np
.
dot
(
self
.
_motion_mat
,
covariance
).
transpose
((
1
,
0
,
2
))
covariance
=
np
.
dot
(
left
,
self
.
_motion_mat
.
T
)
+
motion_cov
return
mean
,
covariance
def
update
(
self
,
mean
,
covariance
,
measurement
):
"""
Run Kalman filter correction step.
Args:
mean (ndarray): The predicted state's mean vector (8 dimensional).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
measurement (ndarray): The 4 dimensional measurement vector
(x, y, a, h), where (x, y) is the center position, a the aspect
ratio, and h the height of the bounding box.
Returns:
The measurement-corrected state distribution.
"""
projected_mean
,
projected_cov
=
self
.
project
(
mean
,
covariance
)
chol_factor
,
lower
=
scipy
.
linalg
.
cho_factor
(
projected_cov
,
lower
=
True
,
check_finite
=
False
)
kalman_gain
=
scipy
.
linalg
.
cho_solve
((
chol_factor
,
lower
),
np
.
dot
(
covariance
,
self
.
_update_mat
.
T
).
T
,
check_finite
=
False
).
T
innovation
=
measurement
-
projected_mean
new_mean
=
mean
+
np
.
dot
(
innovation
,
kalman_gain
.
T
)
new_covariance
=
covariance
-
np
.
linalg
.
multi_dot
((
kalman_gain
,
projected_cov
,
kalman_gain
.
T
))
return
new_mean
,
new_covariance
def
gating_distance
(
self
,
mean
,
covariance
,
measurements
,
only_position
=
False
,
metric
=
'maha'
):
"""
Compute gating distance between state distribution and measurements.
A suitable distance threshold can be obtained from `chi2inv95`. If
`only_position` is False, the chi-square distribution has 4 degrees of
freedom, otherwise 2.
Args:
mean (ndarray): Mean vector over the state distribution (8
dimensional).
covariance (ndarray): Covariance of the state distribution (8x8
dimensional).
measurements (ndarray): An Nx4 dimensional matrix of N measurements,
each in format (x, y, a, h) where (x, y) is the bounding box center
position, a the aspect ratio, and h the height.
only_position (Optional[bool]): If True, distance computation is
done with respect to the bounding box center position only.
metric (str): Metric type, 'gaussian' or 'maha'.
Returns
An array of length N, where the i-th element contains the squared
Mahalanobis distance between (mean, covariance) and `measurements[i]`.
"""
mean
,
covariance
=
self
.
project
(
mean
,
covariance
)
if
only_position
:
mean
,
covariance
=
mean
[:
2
],
covariance
[:
2
,
:
2
]
measurements
=
measurements
[:,
:
2
]
d
=
measurements
-
mean
if
metric
==
'gaussian'
:
return
np
.
sum
(
d
*
d
,
axis
=
1
)
elif
metric
==
'maha'
:
cholesky_factor
=
np
.
linalg
.
cholesky
(
covariance
)
z
=
scipy
.
linalg
.
solve_triangular
(
cholesky_factor
,
d
.
T
,
lower
=
True
,
check_finite
=
False
,
overwrite_b
=
True
)
squared_maha
=
np
.
sum
(
z
*
z
,
axis
=
0
)
return
squared_maha
else
:
raise
ValueError
(
'invalid distance metric'
)
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/__init__.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
base_jde_tracker
from
.
import
base_sde_tracker
from
.
import
jde_tracker
from
.base_jde_tracker
import
*
from
.base_sde_tracker
import
*
from
.jde_tracker
import
*
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/base_jde_tracker.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import
numpy
as
np
from
collections
import
deque
,
OrderedDict
from
..matching
import
jde_matching
as
matching
from
ppdet.core.workspace
import
register
,
serializable
__all__
=
[
'TrackState'
,
'BaseTrack'
,
'STrack'
,
'joint_stracks'
,
'sub_stracks'
,
'remove_duplicate_stracks'
,
]
class
TrackState
(
object
):
New
=
0
Tracked
=
1
Lost
=
2
Removed
=
3
class
BaseTrack
(
object
):
_count
=
0
track_id
=
0
is_activated
=
False
state
=
TrackState
.
New
history
=
OrderedDict
()
features
=
[]
curr_feature
=
None
score
=
0
start_frame
=
0
frame_id
=
0
time_since_update
=
0
# multi-camera
location
=
(
np
.
inf
,
np
.
inf
)
@
property
def
end_frame
(
self
):
return
self
.
frame_id
@
staticmethod
def
next_id
():
BaseTrack
.
_count
+=
1
return
BaseTrack
.
_count
def
activate
(
self
,
*
args
):
raise
NotImplementedError
def
predict
(
self
):
raise
NotImplementedError
def
update
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
def
mark_lost
(
self
):
self
.
state
=
TrackState
.
Lost
def
mark_removed
(
self
):
self
.
state
=
TrackState
.
Removed
class
STrack
(
BaseTrack
):
def
__init__
(
self
,
tlwh
,
score
,
temp_feat
,
buffer_size
=
30
):
# wait activate
self
.
_tlwh
=
np
.
asarray
(
tlwh
,
dtype
=
np
.
float
)
self
.
kalman_filter
=
None
self
.
mean
,
self
.
covariance
=
None
,
None
self
.
is_activated
=
False
self
.
score
=
score
self
.
tracklet_len
=
0
self
.
smooth_feat
=
None
self
.
update_features
(
temp_feat
)
self
.
features
=
deque
([],
maxlen
=
buffer_size
)
self
.
alpha
=
0.9
def
update_features
(
self
,
feat
):
feat
/=
np
.
linalg
.
norm
(
feat
)
self
.
curr_feat
=
feat
if
self
.
smooth_feat
is
None
:
self
.
smooth_feat
=
feat
else
:
self
.
smooth_feat
=
self
.
alpha
*
self
.
smooth_feat
+
(
1
-
self
.
alpha
)
*
feat
self
.
features
.
append
(
feat
)
self
.
smooth_feat
/=
np
.
linalg
.
norm
(
self
.
smooth_feat
)
def
predict
(
self
):
mean_state
=
self
.
mean
.
copy
()
if
self
.
state
!=
TrackState
.
Tracked
:
mean_state
[
7
]
=
0
self
.
mean
,
self
.
covariance
=
self
.
kalman_filter
.
predict
(
mean_state
,
self
.
covariance
)
@
staticmethod
def
multi_predict
(
stracks
,
kalman_filter
):
if
len
(
stracks
)
>
0
:
multi_mean
=
np
.
asarray
([
st
.
mean
.
copy
()
for
st
in
stracks
])
multi_covariance
=
np
.
asarray
([
st
.
covariance
for
st
in
stracks
])
for
i
,
st
in
enumerate
(
stracks
):
if
st
.
state
!=
TrackState
.
Tracked
:
multi_mean
[
i
][
7
]
=
0
multi_mean
,
multi_covariance
=
kalman_filter
.
multi_predict
(
multi_mean
,
multi_covariance
)
for
i
,
(
mean
,
cov
)
in
enumerate
(
zip
(
multi_mean
,
multi_covariance
)):
stracks
[
i
].
mean
=
mean
stracks
[
i
].
covariance
=
cov
def
activate
(
self
,
kalman_filter
,
frame_id
):
"""Start a new tracklet"""
self
.
kalman_filter
=
kalman_filter
self
.
track_id
=
self
.
next_id
()
self
.
mean
,
self
.
covariance
=
self
.
kalman_filter
.
initiate
(
self
.
tlwh_to_xyah
(
self
.
_tlwh
))
self
.
tracklet_len
=
0
self
.
state
=
TrackState
.
Tracked
if
frame_id
==
1
:
self
.
is_activated
=
True
self
.
frame_id
=
frame_id
self
.
start_frame
=
frame_id
def
re_activate
(
self
,
new_track
,
frame_id
,
new_id
=
False
):
self
.
mean
,
self
.
covariance
=
self
.
kalman_filter
.
update
(
self
.
mean
,
self
.
covariance
,
self
.
tlwh_to_xyah
(
new_track
.
tlwh
))
self
.
update_features
(
new_track
.
curr_feat
)
self
.
tracklet_len
=
0
self
.
state
=
TrackState
.
Tracked
self
.
is_activated
=
True
self
.
frame_id
=
frame_id
if
new_id
:
self
.
track_id
=
self
.
next_id
()
def
update
(
self
,
new_track
,
frame_id
,
update_feature
=
True
):
self
.
frame_id
=
frame_id
self
.
tracklet_len
+=
1
new_tlwh
=
new_track
.
tlwh
self
.
mean
,
self
.
covariance
=
self
.
kalman_filter
.
update
(
self
.
mean
,
self
.
covariance
,
self
.
tlwh_to_xyah
(
new_tlwh
))
self
.
state
=
TrackState
.
Tracked
self
.
is_activated
=
True
self
.
score
=
new_track
.
score
if
update_feature
:
self
.
update_features
(
new_track
.
curr_feat
)
@
property
def
tlwh
(
self
):
"""
Get current position in bounding box format `(top left x, top left y,
width, height)`.
"""
if
self
.
mean
is
None
:
return
self
.
_tlwh
.
copy
()
ret
=
self
.
mean
[:
4
].
copy
()
ret
[
2
]
*=
ret
[
3
]
ret
[:
2
]
-=
ret
[
2
:]
/
2
return
ret
@
property
def
tlbr
(
self
):
"""
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret
=
self
.
tlwh
.
copy
()
ret
[
2
:]
+=
ret
[:
2
]
return
ret
@
staticmethod
def
tlwh_to_xyah
(
tlwh
):
"""
Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret
=
np
.
asarray
(
tlwh
).
copy
()
ret
[:
2
]
+=
ret
[
2
:]
/
2
ret
[
2
]
/=
ret
[
3
]
return
ret
def
to_xyah
(
self
):
return
self
.
tlwh_to_xyah
(
self
.
tlwh
)
@
staticmethod
def
tlbr_to_tlwh
(
tlbr
):
ret
=
np
.
asarray
(
tlbr
).
copy
()
ret
[
2
:]
-=
ret
[:
2
]
return
ret
@
staticmethod
def
tlwh_to_tlbr
(
tlwh
):
ret
=
np
.
asarray
(
tlwh
).
copy
()
ret
[
2
:]
+=
ret
[:
2
]
return
ret
def
__repr__
(
self
):
return
'OT_{}_({}-{})'
.
format
(
self
.
track_id
,
self
.
start_frame
,
self
.
end_frame
)
def
joint_stracks
(
tlista
,
tlistb
):
exists
=
{}
res
=
[]
for
t
in
tlista
:
exists
[
t
.
track_id
]
=
1
res
.
append
(
t
)
for
t
in
tlistb
:
tid
=
t
.
track_id
if
not
exists
.
get
(
tid
,
0
):
exists
[
tid
]
=
1
res
.
append
(
t
)
return
res
def
sub_stracks
(
tlista
,
tlistb
):
stracks
=
{}
for
t
in
tlista
:
stracks
[
t
.
track_id
]
=
t
for
t
in
tlistb
:
tid
=
t
.
track_id
if
stracks
.
get
(
tid
,
0
):
del
stracks
[
tid
]
return
list
(
stracks
.
values
())
def
remove_duplicate_stracks
(
stracksa
,
stracksb
):
pdist
=
matching
.
iou_distance
(
stracksa
,
stracksb
)
pairs
=
np
.
where
(
pdist
<
0.15
)
dupa
,
dupb
=
list
(),
list
()
for
p
,
q
in
zip
(
*
pairs
):
timep
=
stracksa
[
p
].
frame_id
-
stracksa
[
p
].
start_frame
timeq
=
stracksb
[
q
].
frame_id
-
stracksb
[
q
].
start_frame
if
timep
>
timeq
:
dupb
.
append
(
q
)
else
:
dupa
.
append
(
p
)
resa
=
[
t
for
i
,
t
in
enumerate
(
stracksa
)
if
not
i
in
dupa
]
resb
=
[
t
for
i
,
t
in
enumerate
(
stracksb
)
if
not
i
in
dupb
]
return
resa
,
resb
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/base_sde_tracker.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/track.py
"""
from
ppdet.core.workspace
import
register
,
serializable
__all__
=
[
'TrackState'
,
'Track'
]
class
TrackState
(
object
):
"""
Enumeration type for the single target track state. Newly created tracks are
classified as `tentative` until enough evidence has been collected. Then,
the track state is changed to `confirmed`. Tracks that are no longer alive
are classified as `deleted` to mark them for removal from the set of active
tracks.
"""
Tentative
=
1
Confirmed
=
2
Deleted
=
3
class
Track
(
object
):
"""
A single target track with state space `(x, y, a, h)` and associated
velocities, where `(x, y)` is the center of the bounding box, `a` is the
aspect ratio and `h` is the height.
Args:
mean (ndarray): Mean vector of the initial state distribution.
covariance (ndarray): Covariance matrix of the initial state distribution.
track_id (int): A unique track identifier.
n_init (int): Number of consecutive detections before the track is confirmed.
The track state is set to `Deleted` if a miss occurs within the first
`n_init` frames.
max_age (int): The maximum number of consecutive misses before the track
state is set to `Deleted`.
feature (Optional[ndarray]): Feature vector of the detection this track
originates from. If not None, this feature is added to the `features` cache.
Attributes:
hits (int): Total number of measurement updates.
age (int): Total number of frames since first occurance.
time_since_update (int): Total number of frames since last measurement
update.
state (TrackState): The current track state.
features (List[ndarray]): A cache of features. On each measurement update,
the associated feature vector is added to this list.
"""
def
__init__
(
self
,
mean
,
covariance
,
track_id
,
n_init
,
max_age
,
feature
=
None
):
self
.
mean
=
mean
self
.
covariance
=
covariance
self
.
track_id
=
track_id
self
.
hits
=
1
self
.
age
=
1
self
.
time_since_update
=
0
self
.
state
=
TrackState
.
Tentative
self
.
features
=
[]
if
feature
is
not
None
:
self
.
features
.
append
(
feature
)
self
.
_n_init
=
n_init
self
.
_max_age
=
max_age
def
to_tlwh
(
self
):
"""Get position in format `(top left x, top left y, width, height)`."""
ret
=
self
.
mean
[:
4
].
copy
()
ret
[
2
]
*=
ret
[
3
]
ret
[:
2
]
-=
ret
[
2
:]
/
2
return
ret
def
to_tlbr
(
self
):
"""Get position in bounding box format `(min x, miny, max x, max y)`."""
ret
=
self
.
to_tlwh
()
ret
[
2
:]
=
ret
[:
2
]
+
ret
[
2
:]
return
ret
def
predict
(
self
,
kalman_filter
):
"""
Propagate the state distribution to the current time step using a Kalman
filter prediction step.
"""
self
.
mean
,
self
.
covariance
=
kalman_filter
.
predict
(
self
.
mean
,
self
.
covariance
)
self
.
age
+=
1
self
.
time_since_update
+=
1
def
update
(
self
,
kalman_filter
,
detection
):
"""
Perform Kalman filter measurement update step and update the associated
detection feature cache.
"""
self
.
mean
,
self
.
covariance
=
kalman_filter
.
update
(
self
.
mean
,
self
.
covariance
,
detection
.
to_xyah
())
self
.
features
.
append
(
detection
.
feature
)
self
.
hits
+=
1
self
.
time_since_update
=
0
if
self
.
state
==
TrackState
.
Tentative
and
self
.
hits
>=
self
.
_n_init
:
self
.
state
=
TrackState
.
Confirmed
def
mark_missed
(
self
):
"""Mark this track as missed (no association at the current time step).
"""
if
self
.
state
==
TrackState
.
Tentative
:
self
.
state
=
TrackState
.
Deleted
elif
self
.
time_since_update
>
self
.
_max_age
:
self
.
state
=
TrackState
.
Deleted
def
is_tentative
(
self
):
"""Returns True if this track is tentative (unconfirmed)."""
return
self
.
state
==
TrackState
.
Tentative
def
is_confirmed
(
self
):
"""Returns True if this track is confirmed."""
return
self
.
state
==
TrackState
.
Confirmed
def
is_deleted
(
self
):
"""Returns True if this track is dead and should be deleted."""
return
self
.
state
==
TrackState
.
Deleted
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/jde_tracker.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import
paddle
from
..matching
import
jde_matching
as
matching
from
.base_jde_tracker
import
TrackState
,
BaseTrack
,
STrack
from
.base_jde_tracker
import
joint_stracks
,
sub_stracks
,
remove_duplicate_stracks
from
ppdet.core.workspace
import
register
,
serializable
from
ppdet.utils.logger
import
setup_logger
logger
=
setup_logger
(
__name__
)
__all__
=
[
'FrozenJDETracker'
]
@
register
@
serializable
class
FrozenJDETracker
(
object
):
__inject__
=
[
'motion'
]
"""
JDE tracker
Args:
det_thresh (float): threshold of detection score
track_buffer (int): buffer for tracker
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
bad results, set 1.6 default for pedestrian tracking. If set -1
means no need to filter bboxes.
tracked_thresh (float): linear assignment threshold of tracked
stracks and detections
r_tracked_thresh (float): linear assignment threshold of
tracked stracks and unmatched detections
unconfirmed_thresh (float): linear assignment threshold of
unconfirmed stracks and unmatched detections
motion (object): KalmanFilter instance
conf_thres (float): confidence threshold for tracking
metric_type (str): either "euclidean" or "cosine", the distance metric
used for measurement to track association.
"""
def
__init__
(
self
,
det_thresh
=
0.3
,
track_buffer
=
30
,
min_box_area
=
200
,
vertical_ratio
=
1.6
,
tracked_thresh
=
0.7
,
r_tracked_thresh
=
0.5
,
unconfirmed_thresh
=
0.7
,
motion
=
'KalmanFilter'
,
conf_thres
=
0
,
metric_type
=
'euclidean'
):
self
.
det_thresh
=
det_thresh
self
.
track_buffer
=
track_buffer
self
.
min_box_area
=
min_box_area
self
.
vertical_ratio
=
vertical_ratio
self
.
tracked_thresh
=
tracked_thresh
self
.
r_tracked_thresh
=
r_tracked_thresh
self
.
unconfirmed_thresh
=
unconfirmed_thresh
self
.
motion
=
motion
self
.
conf_thres
=
conf_thres
self
.
metric_type
=
metric_type
self
.
frame_id
=
0
self
.
tracked_stracks
=
[]
self
.
lost_stracks
=
[]
self
.
removed_stracks
=
[]
self
.
max_time_lost
=
0
# max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
def
update
(
self
,
pred_dets
,
pred_embs
):
"""
Processes the image frame and finds bounding box(detections).
Associates the detection with corresponding tracklets and also handles
lost, removed, refound and active tracklets.
Args:
pred_dets (Tensor): Detection results of the image, shape is [N, 5].
pred_embs (Tensor): Embedding results of the image, shape is [N, 512].
Return:
output_stracks (list): The list contains information regarding the
online_tracklets for the recieved image tensor.
"""
self
.
frame_id
+=
1
activated_starcks
=
[]
# for storing active tracks, for the current frame
refind_stracks
=
[]
# Lost Tracks whose detections are obtained in the current frame
lost_stracks
=
[]
# The tracks which are not obtained in the current frame but are not
# removed. (Lost for some time lesser than the threshold for removing)
removed_stracks
=
[]
remain_inds
=
paddle
.
nonzero
(
pred_dets
[:,
4
]
>
self
.
conf_thres
)
if
remain_inds
.
shape
[
0
]
==
0
:
pred_dets
=
paddle
.
zeros
([
0
,
1
])
pred_embs
=
paddle
.
zeros
([
0
,
1
])
else
:
pred_dets
=
paddle
.
gather
(
pred_dets
,
remain_inds
)
pred_embs
=
paddle
.
gather
(
pred_embs
,
remain_inds
)
# Filter out the image with box_num = 0. pred_dets = [[0.0, 0.0, 0.0 ,0.0]]
empty_pred
=
True
if
len
(
pred_dets
)
==
1
and
paddle
.
sum
(
pred_dets
)
==
0.0
else
False
""" Step 1: Network forward, get detections & embeddings"""
if
len
(
pred_dets
)
>
0
and
not
empty_pred
:
pred_dets
=
pred_dets
.
numpy
()
pred_embs
=
pred_embs
.
numpy
()
detections
=
[
STrack
(
STrack
.
tlbr_to_tlwh
(
tlbrs
[:
4
]),
tlbrs
[
4
],
f
,
30
)
for
(
tlbrs
,
f
)
in
zip
(
pred_dets
,
pred_embs
)
]
else
:
detections
=
[]
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed
=
[]
tracked_stracks
=
[]
# type: list[STrack]
for
track
in
self
.
tracked_stracks
:
if
not
track
.
is_activated
:
# previous tracks which are not active in the current frame are added in unconfirmed list
unconfirmed
.
append
(
track
)
else
:
# Active tracks are added to the local list 'tracked_stracks'
tracked_stracks
.
append
(
track
)
""" Step 2: First association, with embedding"""
# Combining currently tracked_stracks and lost_stracks
strack_pool
=
joint_stracks
(
tracked_stracks
,
self
.
lost_stracks
)
# Predict the current location with KF
STrack
.
multi_predict
(
strack_pool
,
self
.
motion
)
dists
=
matching
.
embedding_distance
(
strack_pool
,
detections
,
metric
=
self
.
metric_type
)
dists
=
matching
.
fuse_motion
(
self
.
motion
,
dists
,
strack_pool
,
detections
)
# The dists is the list of distances of the detection with the tracks in strack_pool
matches
,
u_track
,
u_detection
=
matching
.
linear_assignment
(
dists
,
thresh
=
self
.
tracked_thresh
)
# The matches is the array for corresponding matches of the detection with the corresponding strack_pool
for
itracked
,
idet
in
matches
:
# itracked is the id of the track and idet is the detection
track
=
strack_pool
[
itracked
]
det
=
detections
[
idet
]
if
track
.
state
==
TrackState
.
Tracked
:
# If the track is active, add the detection to the track
track
.
update
(
detections
[
idet
],
self
.
frame_id
)
activated_starcks
.
append
(
track
)
else
:
# We have obtained a detection from a track which is not active,
# hence put the track in refind_stracks list
track
.
re_activate
(
det
,
self
.
frame_id
,
new_id
=
False
)
refind_stracks
.
append
(
track
)
# None of the steps below happen if there are no undetected tracks.
""" Step 3: Second association, with IOU"""
detections
=
[
detections
[
i
]
for
i
in
u_detection
]
# detections is now a list of the unmatched detections
r_tracked_stracks
=
[]
# This is container for stracks which were tracked till the previous
# frame but no detection was found for it in the current frame.
for
i
in
u_track
:
if
strack_pool
[
i
].
state
==
TrackState
.
Tracked
:
r_tracked_stracks
.
append
(
strack_pool
[
i
])
dists
=
matching
.
iou_distance
(
r_tracked_stracks
,
detections
)
matches
,
u_track
,
u_detection
=
matching
.
linear_assignment
(
dists
,
thresh
=
self
.
r_tracked_thresh
)
# matches is the list of detections which matched with corresponding
# tracks by IOU distance method.
for
itracked
,
idet
in
matches
:
track
=
r_tracked_stracks
[
itracked
]
det
=
detections
[
idet
]
if
track
.
state
==
TrackState
.
Tracked
:
track
.
update
(
det
,
self
.
frame_id
)
activated_starcks
.
append
(
track
)
else
:
track
.
re_activate
(
det
,
self
.
frame_id
,
new_id
=
False
)
refind_stracks
.
append
(
track
)
# Same process done for some unmatched detections, but now considering IOU_distance as measure
for
it
in
u_track
:
track
=
r_tracked_stracks
[
it
]
if
not
track
.
state
==
TrackState
.
Lost
:
track
.
mark_lost
()
lost_stracks
.
append
(
track
)
# If no detections are obtained for tracks (u_track), the tracks are added to lost_tracks list and are marked lost
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections
=
[
detections
[
i
]
for
i
in
u_detection
]
dists
=
matching
.
iou_distance
(
unconfirmed
,
detections
)
matches
,
u_unconfirmed
,
u_detection
=
matching
.
linear_assignment
(
dists
,
thresh
=
self
.
unconfirmed_thresh
)
for
itracked
,
idet
in
matches
:
unconfirmed
[
itracked
].
update
(
detections
[
idet
],
self
.
frame_id
)
activated_starcks
.
append
(
unconfirmed
[
itracked
])
# The tracks which are yet not matched
for
it
in
u_unconfirmed
:
track
=
unconfirmed
[
it
]
track
.
mark_removed
()
removed_stracks
.
append
(
track
)
# after all these confirmation steps, if a new detection is found, it is initialized for a new track
""" Step 4: Init new stracks"""
for
inew
in
u_detection
:
track
=
detections
[
inew
]
if
track
.
score
<
self
.
det_thresh
:
continue
track
.
activate
(
self
.
motion
,
self
.
frame_id
)
activated_starcks
.
append
(
track
)
""" Step 5: Update state"""
# If the tracks are lost for more frames than the threshold number, the tracks are removed.
for
track
in
self
.
lost_stracks
:
if
self
.
frame_id
-
track
.
end_frame
>
self
.
max_time_lost
:
track
.
mark_removed
()
removed_stracks
.
append
(
track
)
# Update the self.tracked_stracks and self.lost_stracks using the updates in this step.
self
.
tracked_stracks
=
[
t
for
t
in
self
.
tracked_stracks
if
t
.
state
==
TrackState
.
Tracked
]
self
.
tracked_stracks
=
joint_stracks
(
self
.
tracked_stracks
,
activated_starcks
)
self
.
tracked_stracks
=
joint_stracks
(
self
.
tracked_stracks
,
refind_stracks
)
self
.
lost_stracks
=
sub_stracks
(
self
.
lost_stracks
,
self
.
tracked_stracks
)
self
.
lost_stracks
.
extend
(
lost_stracks
)
self
.
lost_stracks
=
sub_stracks
(
self
.
lost_stracks
,
self
.
removed_stracks
)
self
.
removed_stracks
.
extend
(
removed_stracks
)
self
.
tracked_stracks
,
self
.
lost_stracks
=
remove_duplicate_stracks
(
self
.
tracked_stracks
,
self
.
lost_stracks
)
# get scores of lost tracks
output_stracks
=
[
track
for
track
in
self
.
tracked_stracks
if
track
.
is_activated
]
logger
.
debug
(
'===========Frame {}=========='
.
format
(
self
.
frame_id
))
logger
.
debug
(
'Activated: {}'
.
format
([
track
.
track_id
for
track
in
activated_starcks
]))
logger
.
debug
(
'Refind: {}'
.
format
([
track
.
track_id
for
track
in
refind_stracks
]))
logger
.
debug
(
'Lost: {}'
.
format
([
track
.
track_id
for
track
in
lost_stracks
]))
logger
.
debug
(
'Removed: {}'
.
format
([
track
.
track_id
for
track
in
removed_stracks
]))
return
output_stracks
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/utils.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
time
import
paddle
import
numpy
as
np
__all__
=
[
'Timer'
,
'Detection'
,
'load_det_results'
,
'preprocess_reid'
,
'get_crops'
,
'clip_box'
,
'scale_coords'
,
]
class
Timer
(
object
):
"""
This class used to compute and print the current FPS while evaling.
"""
def
__init__
(
self
):
self
.
total_time
=
0.
self
.
calls
=
0
self
.
start_time
=
0.
self
.
diff
=
0.
self
.
average_time
=
0.
self
.
duration
=
0.
def
tic
(
self
):
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
self
.
start_time
=
time
.
time
()
def
toc
(
self
,
average
=
True
):
self
.
diff
=
time
.
time
()
-
self
.
start_time
self
.
total_time
+=
self
.
diff
self
.
calls
+=
1
self
.
average_time
=
self
.
total_time
/
self
.
calls
if
average
:
self
.
duration
=
self
.
average_time
else
:
self
.
duration
=
self
.
diff
return
self
.
duration
def
clear
(
self
):
self
.
total_time
=
0.
self
.
calls
=
0
self
.
start_time
=
0.
self
.
diff
=
0.
self
.
average_time
=
0.
self
.
duration
=
0.
class
Detection
(
object
):
"""
This class represents a bounding box detection in a single image.
Args:
tlwh (ndarray): Bounding box in format `(top left x, top left y,
width, height)`.
confidence (ndarray): Detector confidence score.
feature (Tensor): A feature vector that describes the object
contained in this image.
"""
def
__init__
(
self
,
tlwh
,
confidence
,
feature
):
self
.
tlwh
=
np
.
asarray
(
tlwh
,
dtype
=
np
.
float32
)
self
.
confidence
=
np
.
asarray
(
confidence
,
dtype
=
np
.
float32
)
self
.
feature
=
feature
def
to_tlbr
(
self
):
"""
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret
=
self
.
tlwh
.
copy
()
ret
[
2
:]
+=
ret
[:
2
]
return
ret
def
to_xyah
(
self
):
"""
Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret
=
self
.
tlwh
.
copy
()
ret
[:
2
]
+=
ret
[
2
:]
/
2
ret
[
2
]
/=
ret
[
3
]
return
ret
def
load_det_results
(
det_file
,
num_frames
):
assert
os
.
path
.
exists
(
det_file
)
and
os
.
path
.
isfile
(
det_file
),
\
'Error: det_file: {} not exist or not a file.'
.
format
(
det_file
)
labels
=
np
.
loadtxt
(
det_file
,
dtype
=
'float32'
,
delimiter
=
','
)
results_list
=
[]
for
frame_i
in
range
(
0
,
num_frames
):
results
=
{
'bbox'
:
[],
'score'
:
[]}
lables_with_frame
=
labels
[
labels
[:,
0
]
==
frame_i
+
1
]
for
l
in
lables_with_frame
:
results
[
'bbox'
].
append
(
l
[
1
:
5
])
results
[
'score'
].
append
(
l
[
5
])
results_list
.
append
(
results
)
return
results_list
def
scale_coords
(
coords
,
input_shape
,
im_shape
,
scale_factor
):
im_shape
=
im_shape
.
numpy
()[
0
]
ratio
=
scale_factor
[
0
][
0
]
pad_w
=
(
input_shape
[
1
]
-
int
(
im_shape
[
1
]))
/
2
pad_h
=
(
input_shape
[
0
]
-
int
(
im_shape
[
0
]))
/
2
coords
=
paddle
.
cast
(
coords
,
'float32'
)
coords
[:,
0
::
2
]
-=
pad_w
coords
[:,
1
::
2
]
-=
pad_h
coords
[:,
0
:
4
]
/=
ratio
coords
[:,
:
4
]
=
paddle
.
clip
(
coords
[:,
:
4
],
min
=
0
,
max
=
coords
[:,
:
4
].
max
())
return
coords
.
round
()
def
clip_box
(
xyxy
,
input_shape
,
im_shape
,
scale_factor
):
im_shape
=
im_shape
.
numpy
()[
0
]
ratio
=
scale_factor
.
numpy
()[
0
][
0
]
img0_shape
=
[
int
(
im_shape
[
0
]
/
ratio
),
int
(
im_shape
[
1
]
/
ratio
)]
xyxy
[:,
0
::
2
]
=
paddle
.
clip
(
xyxy
[:,
0
::
2
],
min
=
0
,
max
=
img0_shape
[
1
])
xyxy
[:,
1
::
2
]
=
paddle
.
clip
(
xyxy
[:,
1
::
2
],
min
=
0
,
max
=
img0_shape
[
0
])
return
xyxy
def
get_crops
(
xyxy
,
ori_img
,
pred_scores
,
w
,
h
):
crops
=
[]
keep_scores
=
[]
xyxy
=
xyxy
.
numpy
().
astype
(
np
.
int64
)
ori_img
=
ori_img
.
numpy
()
ori_img
=
np
.
squeeze
(
ori_img
,
axis
=
0
).
transpose
(
1
,
0
,
2
)
pred_scores
=
pred_scores
.
numpy
()
for
i
,
bbox
in
enumerate
(
xyxy
):
if
bbox
[
2
]
<=
bbox
[
0
]
or
bbox
[
3
]
<=
bbox
[
1
]:
continue
crop
=
ori_img
[
bbox
[
0
]:
bbox
[
2
],
bbox
[
1
]:
bbox
[
3
],
:]
crops
.
append
(
crop
)
keep_scores
.
append
(
pred_scores
[
i
])
if
len
(
crops
)
==
0
:
return
[],
[]
crops
=
preprocess_reid
(
crops
,
w
,
h
)
return
crops
,
keep_scores
def
preprocess_reid
(
imgs
,
w
=
64
,
h
=
192
,
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]):
im_batch
=
[]
for
img
in
imgs
:
img
=
cv2
.
resize
(
img
,
(
w
,
h
))
img
=
img
[:,
:,
::
-
1
].
astype
(
'float32'
).
transpose
((
2
,
0
,
1
))
/
255
img_mean
=
np
.
array
(
mean
).
reshape
((
3
,
1
,
1
))
img_std
=
np
.
array
(
std
).
reshape
((
3
,
1
,
1
))
img
-=
img_mean
img
/=
img_std
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
im_batch
.
append
(
img
)
im_batch
=
np
.
concatenate
(
im_batch
,
0
)
return
im_batch
modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/visualization.py
0 → 100644
浏览文件 @
d53f6412
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
cv2
import
numpy
as
np
def
tlwhs_to_tlbrs
(
tlwhs
):
tlbrs
=
np
.
copy
(
tlwhs
)
if
len
(
tlbrs
)
==
0
:
return
tlbrs
tlbrs
[:,
2
]
+=
tlwhs
[:,
0
]
tlbrs
[:,
3
]
+=
tlwhs
[:,
1
]
return
tlbrs
def
get_color
(
idx
):
idx
=
idx
*
3
color
=
((
37
*
idx
)
%
255
,
(
17
*
idx
)
%
255
,
(
29
*
idx
)
%
255
)
return
color
def
resize_image
(
image
,
max_size
=
800
):
if
max
(
image
.
shape
[:
2
])
>
max_size
:
scale
=
float
(
max_size
)
/
max
(
image
.
shape
[:
2
])
image
=
cv2
.
resize
(
image
,
None
,
fx
=
scale
,
fy
=
scale
)
return
image
def
plot_tracking
(
image
,
tlwhs
,
obj_ids
,
scores
=
None
,
frame_id
=
0
,
fps
=
0.
,
ids2
=
None
):
im
=
np
.
ascontiguousarray
(
np
.
copy
(
image
))
im_h
,
im_w
=
im
.
shape
[:
2
]
top_view
=
np
.
zeros
([
im_w
,
im_w
,
3
],
dtype
=
np
.
uint8
)
+
255
text_scale
=
max
(
1
,
image
.
shape
[
1
]
/
1600.
)
text_thickness
=
2
line_thickness
=
max
(
1
,
int
(
image
.
shape
[
1
]
/
500.
))
radius
=
max
(
5
,
int
(
im_w
/
140.
))
cv2
.
putText
(
im
,
'frame: %d fps: %.2f num: %d'
%
(
frame_id
,
fps
,
len
(
tlwhs
)),
(
0
,
int
(
15
*
text_scale
)),
cv2
.
FONT_HERSHEY_PLAIN
,
text_scale
,
(
0
,
0
,
255
),
thickness
=
2
)
for
i
,
tlwh
in
enumerate
(
tlwhs
):
x1
,
y1
,
w
,
h
=
tlwh
intbox
=
tuple
(
map
(
int
,
(
x1
,
y1
,
x1
+
w
,
y1
+
h
)))
obj_id
=
int
(
obj_ids
[
i
])
id_text
=
'{}'
.
format
(
int
(
obj_id
))
if
ids2
is
not
None
:
id_text
=
id_text
+
', {}'
.
format
(
int
(
ids2
[
i
]))
_line_thickness
=
1
if
obj_id
<=
0
else
line_thickness
color
=
get_color
(
abs
(
obj_id
))
cv2
.
rectangle
(
im
,
intbox
[
0
:
2
],
intbox
[
2
:
4
],
color
=
color
,
thickness
=
line_thickness
)
cv2
.
putText
(
im
,
id_text
,
(
intbox
[
0
],
intbox
[
1
]
+
10
),
cv2
.
FONT_HERSHEY_PLAIN
,
text_scale
,
(
0
,
0
,
255
),
thickness
=
text_thickness
)
if
scores
is
not
None
:
text
=
'{:.2f}'
.
format
(
float
(
scores
[
i
]))
cv2
.
putText
(
im
,
text
,
(
intbox
[
0
],
intbox
[
1
]
-
10
),
cv2
.
FONT_HERSHEY_PLAIN
,
text_scale
,
(
0
,
255
,
255
),
thickness
=
text_thickness
)
return
im
def
plot_trajectory
(
image
,
tlwhs
,
track_ids
):
image
=
image
.
copy
()
for
one_tlwhs
,
track_id
in
zip
(
tlwhs
,
track_ids
):
color
=
get_color
(
int
(
track_id
))
for
tlwh
in
one_tlwhs
:
x1
,
y1
,
w
,
h
=
tuple
(
map
(
int
,
tlwh
))
cv2
.
circle
(
image
,
(
int
(
x1
+
0.5
*
w
),
int
(
y1
+
h
)),
2
,
color
,
thickness
=
2
)
return
image
def
plot_detections
(
image
,
tlbrs
,
scores
=
None
,
color
=
(
255
,
0
,
0
),
ids
=
None
):
im
=
np
.
copy
(
image
)
text_scale
=
max
(
1
,
image
.
shape
[
1
]
/
800.
)
thickness
=
2
if
text_scale
>
1.3
else
1
for
i
,
det
in
enumerate
(
tlbrs
):
x1
,
y1
,
x2
,
y2
=
np
.
asarray
(
det
[:
4
],
dtype
=
np
.
int
)
if
len
(
det
)
>=
7
:
label
=
'det'
if
det
[
5
]
>
0
else
'trk'
if
ids
is
not
None
:
text
=
'{}# {:.2f}: {:d}'
.
format
(
label
,
det
[
6
],
ids
[
i
])
cv2
.
putText
(
im
,
text
,
(
x1
,
y1
+
30
),
cv2
.
FONT_HERSHEY_PLAIN
,
text_scale
,
(
0
,
255
,
255
),
thickness
=
thickness
)
else
:
text
=
'{}# {:.2f}'
.
format
(
label
,
det
[
6
])
if
scores
is
not
None
:
text
=
'{:.2f}'
.
format
(
scores
[
i
])
cv2
.
putText
(
im
,
text
,
(
x1
,
y1
+
30
),
cv2
.
FONT_HERSHEY_PLAIN
,
text_scale
,
(
0
,
255
,
255
),
thickness
=
thickness
)
cv2
.
rectangle
(
im
,
(
x1
,
y1
),
(
x2
,
y2
),
color
,
2
)
return
im
modules/video/multiple_object_tracking/jde_darknet53/tracker.py
浏览文件 @
d53f6412
...
...
@@ -16,18 +16,19 @@ import cv2
import
glob
import
paddle
import
numpy
as
np
import
collections
from
ppdet.core.workspace
import
create
from
ppdet.utils.checkpoint
import
load_weight
,
load_pretrain_weight
from
ppdet.modeling.mot.utils
import
Detection
,
get_crops
,
scale_coords
,
clip_box
from
ppdet.modeling.mot.utils
import
Timer
,
load_det_results
from
ppdet.modeling.mot
import
visualization
as
mot_vis
from
ppdet.metrics
import
Metric
,
MOTMetric
,
KITTIMOTMetric
import
ppdet.utils.stats
as
stats
from
ppdet.engine.callbacks
import
Callback
,
ComposeCallback
from
ppdet.core.workspace
import
create
from
ppdet.utils.logger
import
setup_logger
from
.dataset
import
MOTVideoStream
,
MOTImageStream
from
.modeling.mot.utils
import
Detection
,
get_crops
,
scale_coords
,
clip_box
from
.modeling.mot
import
visualization
as
mot_vis
from
.utils
import
Timer
logger
=
setup_logger
(
__name__
)
...
...
@@ -70,7 +71,6 @@ class StreamTracker(object):
timer
.
tic
()
pred_dets
,
pred_embs
=
self
.
model
(
data
)
online_targets
=
self
.
model
.
tracker
.
update
(
pred_dets
,
pred_embs
)
online_tlwhs
,
online_ids
=
[],
[]
online_scores
=
[]
for
t
in
online_targets
:
...
...
@@ -109,7 +109,6 @@ class StreamTracker(object):
with
paddle
.
no_grad
():
pred_dets
,
pred_embs
=
self
.
model
(
data
)
online_targets
=
self
.
model
.
tracker
.
update
(
pred_dets
,
pred_embs
)
online_tlwhs
,
online_ids
=
[],
[]
online_scores
=
[]
for
t
in
online_targets
:
...
...
modules/video/multiple_object_tracking/jde_darknet53/utils.py
0 → 100644
浏览文件 @
d53f6412
import
time
class
Timer
(
object
):
"""
This class used to compute and print the current FPS while evaling.
"""
def
__init__
(
self
):
self
.
total_time
=
0.
self
.
calls
=
0
self
.
start_time
=
0.
self
.
diff
=
0.
self
.
average_time
=
0.
self
.
duration
=
0.
def
tic
(
self
):
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
self
.
start_time
=
time
.
time
()
def
toc
(
self
,
average
=
True
):
self
.
diff
=
time
.
time
()
-
self
.
start_time
self
.
total_time
+=
self
.
diff
self
.
calls
+=
1
self
.
average_time
=
self
.
total_time
/
self
.
calls
if
average
:
self
.
duration
=
self
.
average_time
else
:
self
.
duration
=
self
.
diff
return
self
.
duration
def
clear
(
self
):
self
.
total_time
=
0.
self
.
calls
=
0
self
.
start_time
=
0.
self
.
diff
=
0.
self
.
average_time
=
0.
self
.
duration
=
0.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录