Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
b2f3ad7c
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b2f3ad7c
编写于
11月 06, 2021
作者:
F
Feng Ni
提交者:
GitHub
11月 06, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MOT] refine deepsort, fix jde (#4490)
上级
d4a7c9e0
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
41 addition
and
238 deletion
+41
-238
deploy/python/mot_sde_infer.py
deploy/python/mot_sde_infer.py
+14
-29
deploy/python/tracker/__init__.py
deploy/python/tracker/__init__.py
+0
-17
deploy/python/tracker/deepsort_tracker.py
deploy/python/tracker/deepsort_tracker.py
+0
-178
ppdet/engine/tracker.py
ppdet/engine/tracker.py
+3
-2
ppdet/modeling/architectures/jde.py
ppdet/modeling/architectures/jde.py
+1
-1
ppdet/modeling/mot/tracker/deepsort_tracker.py
ppdet/modeling/mot/tracker/deepsort_tracker.py
+9
-8
ppdet/modeling/reid/pplcnet_embedding.py
ppdet/modeling/reid/pplcnet_embedding.py
+14
-3
未找到文件。
deploy/python/mot_sde_infer.py
浏览文件 @
b2f3ad7c
...
@@ -17,19 +17,21 @@ import time
...
@@ -17,19 +17,21 @@ import time
import
yaml
import
yaml
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
import
paddle
from
collections
import
defaultdict
from
benchmark_utils
import
PaddleInferBenchmark
from
preprocess
import
preprocess
from
tracker
import
DeepSORTTracker
from
ppdet.modeling.mot
import
visualization
as
mot_vis
from
ppdet.modeling.mot.utils
import
MOTTimer
import
paddle
from
paddle.inference
import
Config
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
from
paddle.inference
import
create_predictor
from
preprocess
import
preprocess
from
utils
import
argsparser
,
Timer
,
get_current_memory_mb
from
utils
import
argsparser
,
Timer
,
get_current_memory_mb
from
infer
import
get_test_images
,
print_arguments
,
PredictConfig
,
Detector
from
infer
import
Detector
,
get_test_images
,
print_arguments
,
PredictConfig
from
mot_jde_infer
import
write_mot_results
from
infer
import
load_predictor
from
infer
import
load_predictor
from
benchmark_utils
import
PaddleInferBenchmark
from
ppdet.modeling.mot.tracker
import
DeepSORTTracker
from
ppdet.modeling.mot.visualization
import
plot_tracking
from
ppdet.modeling.mot.utils
import
MOTTimer
,
write_mot_results
# Global dictionary
# Global dictionary
MOT_SUPPORT_MODELS
=
{
'DeepSORT'
}
MOT_SUPPORT_MODELS
=
{
'DeepSORT'
}
...
@@ -362,7 +364,7 @@ def predict_image(detector, reid_model, image_list):
...
@@ -362,7 +364,7 @@ def predict_image(detector, reid_model, image_list):
else
:
else
:
online_tlwhs
,
online_scores
,
online_ids
=
reid_model
.
predict
(
online_tlwhs
,
online_scores
,
online_ids
=
reid_model
.
predict
(
crops
,
pred_dets
)
crops
,
pred_dets
)
online_im
=
mot_vis
.
plot_tracking
(
online_im
=
plot_tracking
(
frame
,
online_tlwhs
,
online_ids
,
online_scores
,
frame_id
=
i
)
frame
,
online_tlwhs
,
online_ids
,
online_scores
,
frame_id
=
i
)
if
FLAGS
.
save_images
:
if
FLAGS
.
save_images
:
...
@@ -396,7 +398,7 @@ def predict_video(detector, reid_model, camera_id):
...
@@ -396,7 +398,7 @@ def predict_video(detector, reid_model, camera_id):
writer
=
cv2
.
VideoWriter
(
out_path
,
fourcc
,
fps
,
(
width
,
height
))
writer
=
cv2
.
VideoWriter
(
out_path
,
fourcc
,
fps
,
(
width
,
height
))
frame_id
=
0
frame_id
=
0
timer
=
MOTTimer
()
timer
=
MOTTimer
()
results
=
[]
results
=
defaultdict
(
list
)
while
(
1
):
while
(
1
):
ret
,
frame
=
capture
.
read
()
ret
,
frame
=
capture
.
read
()
if
not
ret
:
if
not
ret
:
...
@@ -415,12 +417,12 @@ def predict_video(detector, reid_model, camera_id):
...
@@ -415,12 +417,12 @@ def predict_video(detector, reid_model, camera_id):
crops
=
reid_model
.
get_crops
(
pred_xyxys
,
frame
)
crops
=
reid_model
.
get_crops
(
pred_xyxys
,
frame
)
online_tlwhs
,
online_scores
,
online_ids
=
reid_model
.
predict
(
online_tlwhs
,
online_scores
,
online_ids
=
reid_model
.
predict
(
crops
,
pred_dets
)
crops
,
pred_dets
)
results
.
append
(
results
[
0
]
.
append
(
(
frame_id
+
1
,
online_tlwhs
,
online_scores
,
online_ids
))
(
frame_id
+
1
,
online_tlwhs
,
online_scores
,
online_ids
))
timer
.
toc
()
timer
.
toc
()
fps
=
1.
/
timer
.
average_time
fps
=
1.
/
timer
.
average_time
im
=
mot_vis
.
plot_tracking
(
im
=
plot_tracking
(
frame
,
frame
,
online_tlwhs
,
online_tlwhs
,
online_ids
,
online_ids
,
...
@@ -437,23 +439,6 @@ def predict_video(detector, reid_model, camera_id):
...
@@ -437,23 +439,6 @@ def predict_video(detector, reid_model, camera_id):
else
:
else
:
writer
.
write
(
im
)
writer
.
write
(
im
)
if
FLAGS
.
save_mot_txt_per_img
:
save_dir
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
video_name
.
split
(
'.'
)[
-
2
])
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
result_filename
=
os
.
path
.
join
(
save_dir
,
'{:05d}.txt'
.
format
(
frame_id
))
# First few frames, the model may have no tracking results but have
# detection results,use the detection results instead, and set id -1.
if
results
[
-
1
][
2
]
==
[]:
tlwhs
=
[
tlwh
for
tlwh
in
pred_dets
[:,
:
4
]]
scores
=
[
score
[
0
]
for
score
in
pred_dets
[:,
4
:
5
]]
ids
=
[
-
1
]
*
len
(
tlwhs
)
result
=
(
frame_id
+
1
,
tlwhs
,
scores
,
ids
)
else
:
result
=
results
[
-
1
]
write_mot_results
(
result_filename
,
[
result
])
frame_id
+=
1
frame_id
+=
1
print
(
'detect frame:%d'
%
(
frame_id
))
print
(
'detect frame:%d'
%
(
frame_id
))
...
...
deploy/python/tracker/__init__.py
已删除
100644 → 0
浏览文件 @
d4a7c9e0
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
deepsort_tracker
from
.deepsort_tracker
import
*
deploy/python/tracker/deepsort_tracker.py
已删除
100644 → 0
浏览文件 @
d4a7c9e0
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/tracker.py
"""
import
numpy
as
np
from
ppdet.modeling.mot.motion
import
KalmanFilter
from
ppdet.modeling.mot.matching.deepsort_matching
import
NearestNeighborDistanceMetric
from
ppdet.modeling.mot.matching.deepsort_matching
import
iou_cost
,
min_cost_matching
,
matching_cascade
,
gate_cost_matrix
from
ppdet.modeling.mot.tracker.base_sde_tracker
import
Track
from
ppdet.modeling.mot.utils
import
Detection
__all__
=
[
'DeepSORTTracker'
]
class
DeepSORTTracker
(
object
):
"""
DeepSORT tracker
Args:
input_size (list): input feature map size to reid model, [h, w] format,
[64, 192] as default.
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
bad results, set 1.6 default for pedestrian tracking. If set <=0
means no need to filter bboxes.
budget (int): If not None, fix samples per class to at most this number.
Removes the oldest samples when the budget is reached.
max_age (int): maximum number of missed misses before a track is deleted
n_init (float): Number of frames that a track remains in initialization
phase. Number of consecutive detections before the track is confirmed.
The track state is set to `Deleted` if a miss occurs within the first
`n_init` frames.
metric_type (str): either "euclidean" or "cosine", the distance metric
used for measurement to track association.
matching_threshold (float): samples with larger distance are
considered an invalid match.
max_iou_distance (float): max iou distance threshold
motion (object): KalmanFilter instance
"""
def
__init__
(
self
,
input_size
=
[
64
,
192
],
min_box_area
=
0
,
vertical_ratio
=-
1
,
budget
=
100
,
max_age
=
70
,
n_init
=
3
,
metric_type
=
'cosine'
,
matching_threshold
=
0.2
,
max_iou_distance
=
0.9
,
motion
=
'KalmanFilter'
):
self
.
input_size
=
input_size
self
.
min_box_area
=
min_box_area
self
.
vertical_ratio
=
vertical_ratio
self
.
max_age
=
max_age
self
.
n_init
=
n_init
self
.
metric
=
NearestNeighborDistanceMetric
(
metric_type
,
matching_threshold
,
budget
)
self
.
max_iou_distance
=
max_iou_distance
self
.
motion
=
KalmanFilter
()
self
.
tracks
=
[]
self
.
_next_id
=
1
def
predict
(
self
):
"""
Propagate track state distributions one time step forward.
This function should be called once every time step, before `update`.
"""
for
track
in
self
.
tracks
:
track
.
predict
(
self
.
motion
)
def
update
(
self
,
pred_dets
,
pred_embs
):
"""
pred_dets (Tensor): Detection results of the image, shape is [N, 6].
pred_embs (Tensor): Embedding results of the image, shape is [N, 128],
usually pred_embs.shape[1] can be a multiple of 128, in PCB
Pyramidal model is 128*21.
"""
pred_tlwhs
=
pred_dets
[:,
:
4
]
pred_scores
=
pred_dets
[:,
4
:
5
]
pred_cls_ids
=
pred_dets
[:,
5
:]
detections
=
[
Detection
(
tlwh
,
score
,
feat
,
cls_id
)
for
tlwh
,
score
,
feat
,
cls_id
in
zip
(
pred_tlwhs
,
pred_scores
,
pred_embs
,
pred_cls_ids
)
]
# Run matching cascade.
matches
,
unmatched_tracks
,
unmatched_detections
=
\
self
.
_match
(
detections
)
# Update track set.
for
track_idx
,
detection_idx
in
matches
:
self
.
tracks
[
track_idx
].
update
(
self
.
motion
,
detections
[
detection_idx
])
for
track_idx
in
unmatched_tracks
:
self
.
tracks
[
track_idx
].
mark_missed
()
for
detection_idx
in
unmatched_detections
:
self
.
_initiate_track
(
detections
[
detection_idx
])
self
.
tracks
=
[
t
for
t
in
self
.
tracks
if
not
t
.
is_deleted
()]
# Update distance metric.
active_targets
=
[
t
.
track_id
for
t
in
self
.
tracks
if
t
.
is_confirmed
()]
features
,
targets
=
[],
[]
for
track
in
self
.
tracks
:
if
not
track
.
is_confirmed
():
continue
features
+=
track
.
features
targets
+=
[
track
.
track_id
for
_
in
track
.
features
]
track
.
features
=
[]
self
.
metric
.
partial_fit
(
np
.
asarray
(
features
),
np
.
asarray
(
targets
),
active_targets
)
output_stracks
=
self
.
tracks
return
output_stracks
def
_match
(
self
,
detections
):
def
gated_metric
(
tracks
,
dets
,
track_indices
,
detection_indices
):
features
=
np
.
array
([
dets
[
i
].
feature
for
i
in
detection_indices
])
targets
=
np
.
array
([
tracks
[
i
].
track_id
for
i
in
track_indices
])
cost_matrix
=
self
.
metric
.
distance
(
features
,
targets
)
cost_matrix
=
gate_cost_matrix
(
self
.
motion
,
cost_matrix
,
tracks
,
dets
,
track_indices
,
detection_indices
)
return
cost_matrix
# Split track set into confirmed and unconfirmed tracks.
confirmed_tracks
=
[
i
for
i
,
t
in
enumerate
(
self
.
tracks
)
if
t
.
is_confirmed
()
]
unconfirmed_tracks
=
[
i
for
i
,
t
in
enumerate
(
self
.
tracks
)
if
not
t
.
is_confirmed
()
]
# Associate confirmed tracks using appearance features.
matches_a
,
unmatched_tracks_a
,
unmatched_detections
=
\
matching_cascade
(
gated_metric
,
self
.
metric
.
matching_threshold
,
self
.
max_age
,
self
.
tracks
,
detections
,
confirmed_tracks
)
# Associate remaining tracks together with unconfirmed tracks using IOU.
iou_track_candidates
=
unconfirmed_tracks
+
[
k
for
k
in
unmatched_tracks_a
if
self
.
tracks
[
k
].
time_since_update
==
1
]
unmatched_tracks_a
=
[
k
for
k
in
unmatched_tracks_a
if
self
.
tracks
[
k
].
time_since_update
!=
1
]
matches_b
,
unmatched_tracks_b
,
unmatched_detections
=
\
min_cost_matching
(
iou_cost
,
self
.
max_iou_distance
,
self
.
tracks
,
detections
,
iou_track_candidates
,
unmatched_detections
)
matches
=
matches_a
+
matches_b
unmatched_tracks
=
list
(
set
(
unmatched_tracks_a
+
unmatched_tracks_b
))
return
matches
,
unmatched_tracks
,
unmatched_detections
def
_initiate_track
(
self
,
detection
):
mean
,
covariance
=
self
.
motion
.
initiate
(
detection
.
to_xyah
())
self
.
tracks
.
append
(
Track
(
mean
,
covariance
,
self
.
_next_id
,
self
.
n_init
,
self
.
max_age
,
detection
.
cls_id
,
detection
.
score
,
detection
.
feature
))
self
.
_next_id
+=
1
ppdet/engine/tracker.py
浏览文件 @
b2f3ad7c
...
@@ -184,7 +184,7 @@ class Tracker(object):
...
@@ -184,7 +184,7 @@ class Tracker(object):
use_detector
=
False
if
not
self
.
model
.
detector
else
True
use_detector
=
False
if
not
self
.
model
.
detector
else
True
timer
=
MOTTimer
()
timer
=
MOTTimer
()
results
=
[]
results
=
defaultdict
(
list
)
frame_id
=
0
frame_id
=
0
self
.
status
[
'mode'
]
=
'track'
self
.
status
[
'mode'
]
=
'track'
self
.
model
.
eval
()
self
.
model
.
eval
()
...
@@ -269,6 +269,7 @@ class Tracker(object):
...
@@ -269,6 +269,7 @@ class Tracker(object):
data
.
update
({
'crops'
:
crops
})
data
.
update
({
'crops'
:
crops
})
pred_embs
=
self
.
model
(
data
)
pred_embs
=
self
.
model
(
data
)
pred_dets
,
pred_embs
=
pred_dets
.
numpy
(),
pred_embs
.
numpy
()
tracker
.
predict
()
tracker
.
predict
()
online_targets
=
tracker
.
update
(
pred_dets
,
pred_embs
)
online_targets
=
tracker
.
update
(
pred_dets
,
pred_embs
)
...
@@ -291,7 +292,7 @@ class Tracker(object):
...
@@ -291,7 +292,7 @@ class Tracker(object):
timer
.
toc
()
timer
.
toc
()
# save results
# save results
results
.
append
(
results
[
0
]
.
append
(
(
frame_id
+
1
,
online_tlwhs
,
online_scores
,
online_ids
))
(
frame_id
+
1
,
online_tlwhs
,
online_scores
,
online_ids
))
save_vis_results
(
data
,
frame_id
,
online_ids
,
online_tlwhs
,
save_vis_results
(
data
,
frame_id
,
online_ids
,
online_tlwhs
,
online_scores
,
timer
.
average_time
,
show_image
,
online_scores
,
timer
.
average_time
,
show_image
,
...
...
ppdet/modeling/architectures/jde.py
浏览文件 @
b2f3ad7c
...
@@ -105,7 +105,7 @@ class JDE(BaseArch):
...
@@ -105,7 +105,7 @@ class JDE(BaseArch):
nms_keep_idx
=
det_outs
[
'nms_keep_idx'
]
nms_keep_idx
=
det_outs
[
'nms_keep_idx'
]
pred_dets
=
paddle
.
concat
((
bbox
[:,
2
:],
bbox
[:,
1
:
2
]),
axis
=
1
)
pred_dets
=
paddle
.
concat
((
bbox
[:,
2
:],
bbox
[:,
1
:
2
]
,
bbox
[:,
0
:
1
]
),
axis
=
1
)
emb_valid
=
paddle
.
gather_nd
(
emb_outs
,
boxes_idx
)
emb_valid
=
paddle
.
gather_nd
(
emb_outs
,
boxes_idx
)
pred_embs
=
paddle
.
gather_nd
(
emb_valid
,
nms_keep_idx
)
pred_embs
=
paddle
.
gather_nd
(
emb_valid
,
nms_keep_idx
)
...
...
ppdet/modeling/mot/tracker/deepsort_tracker.py
浏览文件 @
b2f3ad7c
...
@@ -17,6 +17,7 @@ This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_so
...
@@ -17,6 +17,7 @@ This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_so
import
numpy
as
np
import
numpy
as
np
from
..motion
import
KalmanFilter
from
..matching.deepsort_matching
import
NearestNeighborDistanceMetric
from
..matching.deepsort_matching
import
NearestNeighborDistanceMetric
from
..matching.deepsort_matching
import
iou_cost
,
min_cost_matching
,
matching_cascade
,
gate_cost_matrix
from
..matching.deepsort_matching
import
iou_cost
,
min_cost_matching
,
matching_cascade
,
gate_cost_matrix
from
.base_sde_tracker
import
Track
from
.base_sde_tracker
import
Track
...
@@ -32,7 +33,6 @@ __all__ = ['DeepSORTTracker']
...
@@ -32,7 +33,6 @@ __all__ = ['DeepSORTTracker']
@
register
@
register
@
serializable
@
serializable
class
DeepSORTTracker
(
object
):
class
DeepSORTTracker
(
object
):
__inject__
=
[
'motion'
]
"""
"""
DeepSORT tracker
DeepSORT tracker
...
@@ -77,7 +77,8 @@ class DeepSORTTracker(object):
...
@@ -77,7 +77,8 @@ class DeepSORTTracker(object):
self
.
metric
=
NearestNeighborDistanceMetric
(
metric_type
,
self
.
metric
=
NearestNeighborDistanceMetric
(
metric_type
,
matching_threshold
,
budget
)
matching_threshold
,
budget
)
self
.
max_iou_distance
=
max_iou_distance
self
.
max_iou_distance
=
max_iou_distance
self
.
motion
=
motion
if
motion
==
'KalmanFilter'
:
self
.
motion
=
KalmanFilter
()
self
.
tracks
=
[]
self
.
tracks
=
[]
self
.
_next_id
=
1
self
.
_next_id
=
1
...
@@ -94,14 +95,14 @@ class DeepSORTTracker(object):
...
@@ -94,14 +95,14 @@ class DeepSORTTracker(object):
"""
"""
Perform measurement update and track management.
Perform measurement update and track management.
Args:
Args:
pred_dets (
Tensor): Detection results of the image, shape is [N, 6].
pred_dets (
np.array): Detection results of the image, the shape is
pred_embs (Tensor): Embedding results of the image, shape is [N, 128],
[N, 6], means 'x0, y0, x1, y1, score, cls_id'.
usually pred_embs.shape[1] can be a multiple of 128, in PCB
pred_embs (np.array): Embedding results of the image, the shape is
Pyramidal model is 128*21
.
[N, 128], usually pred_embs.shape[1] is a multiple of 128
.
"""
"""
pred_tlwhs
=
pred_dets
[:,
:
4
]
pred_tlwhs
=
pred_dets
[:,
:
4
]
pred_scores
=
pred_dets
[:,
4
:
5
]
.
squeeze
(
1
)
pred_scores
=
pred_dets
[:,
4
:
5
]
pred_cls_ids
=
pred_dets
[:,
5
:]
.
squeeze
(
1
)
pred_cls_ids
=
pred_dets
[:,
5
:]
detections
=
[
detections
=
[
Detection
(
tlwh
,
score
,
feat
,
cls_id
)
Detection
(
tlwh
,
score
,
feat
,
cls_id
)
...
...
ppdet/modeling/reid/pplcnet_embedding.py
浏览文件 @
b2f3ad7c
...
@@ -21,9 +21,9 @@ import paddle.nn as nn
...
@@ -21,9 +21,9 @@ import paddle.nn as nn
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
from
paddle.nn.initializer
import
Normal
,
Constant
from
paddle.nn.initializer
import
Normal
,
Constant
from
paddle
import
ParamAttr
from
paddle
import
ParamAttr
from
paddle.nn
import
AdaptiveAvgPool2D
,
BatchNorm
,
Conv2D
,
Dropout
,
Linear
from
paddle.nn
import
AdaptiveAvgPool2D
,
BatchNorm
,
Conv2D
,
Linear
from
paddle.regularizer
import
L2Decay
from
paddle.regularizer
import
L2Decay
from
paddle.nn.initializer
import
KaimingNormal
from
paddle.nn.initializer
import
KaimingNormal
,
XavierNormal
from
ppdet.core.workspace
import
register
from
ppdet.core.workspace
import
register
__all__
=
[
'PPLCNetEmbedding'
]
__all__
=
[
'PPLCNetEmbedding'
]
...
@@ -250,6 +250,17 @@ class PPLCNet(nn.Layer):
...
@@ -250,6 +250,17 @@ class PPLCNet(nn.Layer):
return
x
return
x
class
FC
(
nn
.
Layer
):
def
__init__
(
self
,
input_ch
,
output_ch
):
super
(
FC
,
self
).
__init__
()
weight_attr
=
ParamAttr
(
initializer
=
XavierNormal
())
self
.
fc
=
paddle
.
nn
.
Linear
(
input_ch
,
output_ch
,
weight_attr
=
weight_attr
)
def
forward
(
self
,
x
):
out
=
self
.
fc
(
x
)
return
out
@
register
@
register
class
PPLCNetEmbedding
(
nn
.
Layer
):
class
PPLCNetEmbedding
(
nn
.
Layer
):
"""
"""
...
@@ -262,7 +273,7 @@ class PPLCNetEmbedding(nn.Layer):
...
@@ -262,7 +273,7 @@ class PPLCNetEmbedding(nn.Layer):
def
__init__
(
self
,
scale
=
2.5
,
input_ch
=
1280
,
output_ch
=
512
):
def
__init__
(
self
,
scale
=
2.5
,
input_ch
=
1280
,
output_ch
=
512
):
super
(
PPLCNetEmbedding
,
self
).
__init__
()
super
(
PPLCNetEmbedding
,
self
).
__init__
()
self
.
backbone
=
PPLCNet
(
scale
=
scale
)
self
.
backbone
=
PPLCNet
(
scale
=
scale
)
self
.
neck
=
nn
.
Linear
(
input_ch
,
output_ch
)
self
.
neck
=
FC
(
input_ch
,
output_ch
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
feat
=
self
.
backbone
(
x
)
feat
=
self
.
backbone
(
x
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录