Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
ed098b3c
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ed098b3c
编写于
6月 04, 2021
作者:
W
weishengyu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify trainer
上级
e61af9cf
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
184 addition
and
256 deletion
+184
-256
ppcls/engine/trainer.py
ppcls/engine/trainer.py
+183
-43
ppcls/engine/trainer_reid.py
ppcls/engine/trainer_reid.py
+0
-208
tools/train.py
tools/train.py
+1
-5
未找到文件。
ppcls/engine/trainer.py
浏览文件 @
ed098b3c
...
...
@@ -81,6 +81,17 @@ class Trainer(object):
self
.
vdl_writer
=
LogWriter
(
logdir
=
vdl_writer_path
)
logger
.
info
(
'train with paddle {} and device {}'
.
format
(
paddle
.
__version__
,
self
.
device
))
# init members
self
.
train_dataloader
=
None
self
.
eval_dataloader
=
None
self
.
gallery_dataloader
=
None
self
.
query_dataloader
=
None
self
.
eval_mode
=
self
.
config
[
"Global"
].
get
(
"eval_mode"
,
"classification"
)
self
.
train_loss_func
=
None
self
.
eval_loss_func
=
None
self
.
train_metric_func
=
None
self
.
eval_metric_func
=
None
def
_build_metric_info
(
self
,
metric_config
,
mode
=
"train"
):
"""
...
...
@@ -108,16 +119,17 @@ class Trainer(object):
def
train
(
self
):
# build train loss and metric info
loss_func
=
self
.
_build_loss_info
(
self
.
config
[
"Loss"
])
if
"Metric"
in
self
.
config
:
metric_func
=
self
.
_build_metric_info
(
self
.
config
[
"Metric"
])
else
:
metric_func
=
None
if
self
.
train_loss_func
is
None
:
self
.
train_loss_func
=
self
.
_build_loss_info
(
self
.
config
[
"Loss"
])
if
"Metric"
in
self
.
config
and
self
.
train_metric_func
is
None
:
self
.
train_metric_func
=
self
.
_build_metric_info
(
self
.
config
[
"Metric"
])
train_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Train"
,
self
.
device
)
if
self
.
train_dataloader
is
None
:
self
.
train_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Train"
,
self
.
device
)
step_each_epoch
=
len
(
train_dataloader
)
step_each_epoch
=
len
(
self
.
train_dataloader
)
optimizer
,
lr_sch
=
build_optimizer
(
self
.
config
[
"Optimizer"
],
self
.
config
[
"Global"
][
"epochs"
],
...
...
@@ -147,7 +159,7 @@ class Trainer(object):
self
.
config
[
"Global"
][
"epochs"
]
+
1
):
acc
=
0.0
self
.
model
.
train
()
for
iter_id
,
batch
in
enumerate
(
train_dataloader
()):
for
iter_id
,
batch
in
enumerate
(
self
.
train_dataloader
()):
batch_size
=
batch
[
0
].
shape
[
0
]
batch
[
1
]
=
paddle
.
to_tensor
(
batch
[
1
].
numpy
().
astype
(
"int64"
)
.
reshape
([
-
1
,
1
]))
...
...
@@ -158,15 +170,15 @@ class Trainer(object):
else
:
out
=
self
.
model
(
batch
[
0
],
batch
[
1
])
# calc loss
loss_dict
=
loss_func
(
out
,
batch
[
1
])
loss_dict
=
self
.
train_
loss_func
(
out
,
batch
[
1
])
for
key
in
loss_dict
:
if
not
key
in
output_info
:
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
output_info
[
key
].
update
(
loss_dict
[
key
].
numpy
()[
0
],
batch_size
)
# calc metric
if
metric_func
is
not
None
:
metric_dict
=
metric_func
(
out
,
batch
[
-
1
])
if
self
.
train_
metric_func
is
not
None
:
metric_dict
=
self
.
train_
metric_func
(
out
,
batch
[
-
1
])
for
key
in
metric_dict
:
if
not
key
in
output_info
:
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
...
...
@@ -181,7 +193,7 @@ class Trainer(object):
])
logger
.
info
(
"[Train][Epoch {}][Iter: {}/{}]{}, {}"
.
format
(
epoch_id
,
iter_id
,
len
(
train_dataloader
),
lr_msg
,
metric_msg
))
len
(
self
.
train_dataloader
),
lr_msg
,
metric_msg
))
# step opt and lr
loss_dict
[
"loss"
].
backward
()
...
...
@@ -212,6 +224,7 @@ class Trainer(object):
self
.
output_dir
,
model_name
=
self
.
config
[
"Arch"
][
"name"
],
prefix
=
"best_model"
)
self
.
model
.
train
()
# save model
if
epoch_id
%
save_interval
==
0
:
...
...
@@ -228,20 +241,41 @@ class Trainer(object):
@
paddle
.
no_grad
()
def
eval
(
self
,
epoch_id
=
0
):
output_info
=
dict
()
if
self
.
eval_dataloader
is
None
:
self
.
eval_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Eval"
,
self
.
device
)
if
self
.
gallery_dataloader
is
None
:
self
.
gallery_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Gallery"
,
self
.
device
)
if
self
.
query_dataloader
is
None
:
self
.
query_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Query"
,
self
.
device
)
# build train loss and metric info
if
self
.
eval_loss_func
is
None
:
self
.
eval_loss_func
=
self
.
_build_loss_info
(
self
.
config
[
"Loss"
],
"eval"
)
eval_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Eval"
,
self
.
device
)
if
self
.
eval_metric_func
is
None
:
self
.
eval_metric_func
=
self
.
_build_metric_info
(
self
.
config
[
"Metric"
],
"eval"
)
self
.
model
.
eval
()
if
self
.
eval_mode
==
"classification"
:
self
.
eval_cls
(
epoch_id
)
elif
self
.
eval_mode
==
"retrieval"
:
self
.
eval_retrieval
(
epoch_id
)
else
:
logger
.
warning
(
"Invalid eval mode: {}"
.
format
(
self
.
eval_mode
))
def
eval_cls
(
self
,
epoch_id
=
0
):
output_info
=
dict
()
print_batch_step
=
self
.
config
[
"Global"
][
"print_batch_step"
]
# build train loss and metric info
loss_func
=
self
.
_build_loss_info
(
self
.
config
[
"Loss"
],
"eval"
)
metric_func
=
self
.
_build_metric_info
(
self
.
config
[
"Metric"
],
"eval"
)
metric_key
=
None
for
iter_id
,
batch
in
enumerate
(
eval_dataloader
()):
for
iter_id
,
batch
in
enumerate
(
self
.
eval_dataloader
()):
batch_size
=
batch
[
0
].
shape
[
0
]
batch
[
0
]
=
paddle
.
to_tensor
(
batch
[
0
]).
astype
(
"float32"
)
batch
[
1
]
=
paddle
.
to_tensor
(
batch
[
1
]).
reshape
([
-
1
,
1
])
...
...
@@ -250,32 +284,32 @@ class Trainer(object):
out
=
self
.
model
(
batch
[
0
],
batch
[
1
])
else
:
out
=
self
.
model
(
batch
[
0
])
# calc
build
if
loss_func
is
not
None
:
loss_dict
=
loss_func
(
out
,
batch
[
-
1
])
# calc
loss
if
self
.
eval_
loss_func
is
not
None
:
loss_dict
=
self
.
eval_
loss_func
(
out
,
batch
[
-
1
])
for
key
in
loss_dict
:
if
not
key
in
output_info
:
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
output_info
[
key
].
update
(
loss_dict
[
key
].
numpy
()[
0
],
batch_size
)
# calc metric
if
metric_func
is
not
None
:
metric_dict
=
metric_func
(
out
,
batch
[
-
1
])
if
paddle
.
distributed
.
get_world_size
()
>
1
:
for
key
in
metric_dict
:
paddle
.
distributed
.
all_reduce
(
metric_dict
[
key
],
op
=
paddle
.
distributed
.
ReduceOp
.
SUM
)
metric_dict
[
key
]
=
metric_dict
[
key
]
/
paddle
.
distributed
.
get_world_size
()
# calc metric
if
self
.
eval_metric_func
is
not
None
:
metric_dict
=
self
.
eval_metric_func
(
out
,
batch
[
-
1
])
if
paddle
.
distributed
.
get_world_size
()
>
1
:
for
key
in
metric_dict
:
if
metric_key
is
None
:
metric_key
=
key
if
not
key
in
output_info
:
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
paddle
.
distributed
.
all_reduce
(
metric_dict
[
key
],
op
=
paddle
.
distributed
.
ReduceOp
.
SUM
)
metric_dict
[
key
]
=
metric_dict
[
key
]
/
paddle
.
distributed
.
get_world_size
()
for
key
in
metric_dict
:
if
metric_key
is
None
:
metric_key
=
key
if
not
key
in
output_info
:
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
output_info
[
key
].
update
(
metric_dict
[
key
].
numpy
()[
0
],
batch_size
)
output_info
[
key
].
update
(
metric_dict
[
key
].
numpy
()[
0
],
batch_size
)
if
iter_id
%
print_batch_step
==
0
:
metric_msg
=
", "
.
join
([
...
...
@@ -283,7 +317,7 @@ class Trainer(object):
for
key
in
output_info
])
logger
.
info
(
"[Eval][Epoch {}][Iter: {}/{}]{}"
.
format
(
epoch_id
,
iter_id
,
len
(
eval_dataloader
),
metric_msg
))
epoch_id
,
iter_id
,
len
(
self
.
eval_dataloader
),
metric_msg
))
metric_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
...
...
@@ -291,13 +325,119 @@ class Trainer(object):
])
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
))
self
.
model
.
train
()
# do not try to save best model
if
metric_func
is
None
:
if
self
.
eval_
metric_func
is
None
:
return
-
1
# return 1st metric in the dict
return
output_info
[
metric_key
].
avg
def
eval_retrieval
(
self
,
epoch_id
=
0
):
output_info
=
dict
()
self
.
model
.
eval
()
# step1. build gallery
gallery_feas
,
gallery_img_id
,
gallery_camera_id
=
self
.
_cal_feature
(
name
=
'gallery'
)
query_feas
,
query_img_id
,
query_camera_id
=
self
.
_cal_feature
(
name
=
'query'
)
gallery_img_id
=
paddle
.
to_tensor
([
gallery_img_id
]).
t
()
if
gallery_camera_id
is
not
None
:
gallery_camera_id
=
paddle
.
to_tensor
(
gallery_camera_id
).
t
()
query_img_id
=
paddle
.
to_tensor
(
query_img_id
)
if
query_camera_id
is
not
None
:
query_camera_id
=
paddle
.
to_tensor
(
query_camera_id
)
# step2. do evaluation
sim_block_size
=
self
.
config
[
"Global"
].
get
(
"sim_block_size"
,
1
)
sections
=
[
sim_block_size
]
*
(
len
(
query_feas
)
//
sim_block_size
)
if
not
len
(
query_feas
)
%
sim_block_size
:
sections
.
append
(
len
(
query_feas
)
%
sim_block_size
)
fea_blocks
=
paddle
.
split
(
query_feas
,
num_or_sections
=
sections
)
camera_id_blocks
=
paddle
.
split
(
query_camera_id
,
num_or_sections
=
sections
)
image_id_blocks
=
paddle
.
split
(
query_img_id
,
num_or_sections
=
sections
)
metric_key
=
None
for
block_idx
,
block_fea
in
enumerate
(
fea_blocks
):
similarities_matrix
=
paddle
.
matmul
(
block_fea
,
gallery_feas
,
transpose_y
=
True
)
image_id_block
=
image_id_blocks
[
block_idx
]
image_id_mask
=
(
image_id_block
==
gallery_img_id
)
similarities_matrix
=
similarities_matrix
.
masked_select
(
image_id_mask
)
camera_id_block
=
camera_id_blocks
[
block_idx
]
camera_id_mask
=
(
camera_id_block
==
gallery_camera_id
)
similarities_matrix
=
similarities_matrix
.
masked_select
(
camera_id_mask
)
# calc metric
if
self
.
eval_metric_func
is
not
None
:
metric_dict
=
self
.
eval_metric_func
(
similarities_matrix
,
image_id_block
)
for
key
in
metric_dict
:
if
metric_key
is
None
:
metric_key
=
key
if
not
key
in
output_info
:
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
output_info
[
key
].
update
(
metric_dict
[
key
].
numpy
()[
0
],
len
(
image_id_block
))
def
_cal_feature
(
self
,
name
=
'gallery'
):
all_feas
=
None
all_image_id
=
None
all_camera_id
=
None
if
name
==
'gallery'
:
dataloader
=
self
.
gallery_dataloader
elif
name
==
'query'
:
dataloader
=
self
.
query_dataloader
else
:
raise
RuntimeError
(
"Only support gallery or query dataset"
)
has_cam_id
=
False
for
idx
,
batch
in
enumerate
(
dataloader
(
)):
# load is very time-consuming
batch
=
[
paddle
.
to_tensor
(
x
)
for
x
in
batch
]
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
])
if
len
(
batch
)
==
3
:
has_cam_id
=
True
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
])
out
=
self
.
model
(
batch
[
0
],
batch
[
1
])
batch_feas
=
out
[
"features"
]
# do norm
if
self
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
):
feas_norm
=
paddle
.
sqrt
(
paddle
.
sum
(
paddle
.
square
(
batch_feas
),
axis
=
1
,
keepdim
=
True
))
batch_feas
=
paddle
.
divide
(
batch_feas
,
feas_norm
)
if
all_feas
is
None
:
all_feas
=
batch_feas
if
has_cam_id
:
all_camera_id
=
batch
[
2
]
all_image_id
=
batch
[
1
]
else
:
all_feas
=
paddle
.
concat
([
all_feas
,
batch_feas
])
all_image_id
=
paddle
.
concat
([
all_image_id
,
batch
[
1
]])
if
has_cam_id
:
all_camera_id
=
paddle
.
concat
([
all_camera_id
,
batch
[
2
]])
if
paddle
.
distributed
.
get_world_size
()
>
1
:
feat_list
=
[]
img_id_list
=
[]
cam_id_list
=
[]
paddle
.
distributed
.
all_gather
(
feat_list
,
all_feas
)
paddle
.
distributed
.
all_gather
(
img_id_list
,
all_image_id
)
all_feas
=
paddle
.
concat
(
feat_list
,
axis
=
0
)
all_image_id
=
paddle
.
concat
(
img_id_list
,
axis
=
0
)
if
has_cam_id
:
paddle
.
distributed
.
all_gather
(
cam_id_list
,
all_camera_id
)
all_camera_id
=
paddle
.
concat
(
cam_id_list
,
axis
=
0
)
logger
.
info
(
"Build {} done, all feat shape: {}, begin to eval.."
.
format
(
name
,
all_feas
.
shape
))
return
all_feas
,
all_image_id
,
all_camera_id
@
paddle
.
no_grad
()
def
infer
(
self
,
):
total_trainer
=
paddle
.
distributed
.
get_world_size
()
...
...
ppcls/engine/trainer_reid.py
已删除
100644 → 0
浏览文件 @
e61af9cf
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../../'
)))
import
numpy
as
np
import
paddle
from
.trainer
import
Trainer
from
ppcls.utils
import
logger
from
ppcls.data
import
build_dataloader
class
TrainerReID
(
Trainer
):
def
__init__
(
self
,
config
,
mode
=
"train"
):
super
().
__init__
(
config
,
mode
)
self
.
gallery_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Gallery"
,
self
.
device
)
self
.
query_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Query"
,
self
.
device
)
@
paddle
.
no_grad
()
def
eval
(
self
,
epoch_id
=
0
):
output_info
=
dict
()
self
.
model
.
eval
()
print_batch_step
=
self
.
config
[
"Global"
][
"print_batch_step"
]
# step1. build gallery
gallery_feas
,
gallery_img_id
,
gallery_camera_id
=
self
.
_cal_feature
(
name
=
'gallery'
)
query_feas
,
query_img_id
,
query_camera_id
=
self
.
_cal_feature
(
name
=
'query'
)
# step2. do evaluation
if
"num_split"
in
self
.
config
[
"Global"
]:
num_split
=
self
.
config
[
"Global"
][
"num_split"
]
else
:
num_split
=
1
fea_blocks
=
paddle
.
split
(
query_feas
,
num_or_sections
=
1
)
total_similarities_matrix
=
None
for
block_fea
in
fea_blocks
:
similarities_matrix
=
paddle
.
matmul
(
block_fea
,
gallery_feas
,
transpose_y
=
True
)
if
total_similarities_matrix
is
None
:
total_similarities_matrix
=
similarities_matrix
else
:
total_similarities_matrix
=
paddle
.
concat
(
[
total_similarities_matrix
,
similarities_matrix
])
# distmat = (1 - total_similarities_matrix).numpy()
q_pids
=
query_img_id
.
numpy
().
reshape
((
query_img_id
.
shape
[
0
]))
g_pids
=
gallery_img_id
.
numpy
().
reshape
((
gallery_img_id
.
shape
[
0
]))
if
query_camera_id
is
not
None
and
gallery_camera_id
is
not
None
:
q_camids
=
query_camera_id
.
numpy
().
reshape
(
(
query_camera_id
.
shape
[
0
]))
g_camids
=
gallery_camera_id
.
numpy
().
reshape
(
(
gallery_camera_id
.
shape
[
0
]))
max_rank
=
50
num_q
,
num_g
=
total_similarities_matrix
.
shape
if
num_g
<
max_rank
:
max_rank
=
num_g
print
(
'Note: number of gallery samples is quite small, got {}'
.
format
(
num_g
))
# indices = np.argsort(distmat, axis=1)
indices
=
paddle
.
argsort
(
total_similarities_matrix
,
axis
=
1
,
descending
=
True
).
numpy
()
matches
=
(
g_pids
[
indices
]
==
q_pids
[:,
np
.
newaxis
]).
astype
(
np
.
int32
)
# compute cmc curve for each query
all_cmc
=
[]
all_AP
=
[]
all_INP
=
[]
num_valid_q
=
0.
# number of valid query
for
q_idx
in
range
(
num_q
):
# get query pid and camid
q_pid
=
q_pids
[
q_idx
]
q_camid
=
q_camids
[
q_idx
]
# remove gallery samples that have the same pid and camid with query
order
=
indices
[
q_idx
]
if
query_camera_id
is
not
None
and
gallery_camera_id
is
not
None
:
remove
=
(
g_pids
[
order
]
==
q_pid
)
&
(
g_camids
[
order
]
==
q_camid
)
else
:
remove
=
g_pids
[
order
]
==
q_pid
keep
=
np
.
invert
(
remove
)
# compute cmc curve
raw_cmc
=
matches
[
q_idx
][
keep
]
# binary vector, positions with value 1 are correct matches
if
not
np
.
any
(
raw_cmc
):
# this condition is true when query identity does not appear in gallery
continue
cmc
=
raw_cmc
.
cumsum
()
pos_idx
=
np
.
where
(
raw_cmc
==
1
)
max_pos_idx
=
np
.
max
(
pos_idx
)
inp
=
cmc
[
max_pos_idx
]
/
(
max_pos_idx
+
1.0
)
all_INP
.
append
(
inp
)
cmc
[
cmc
>
1
]
=
1
all_cmc
.
append
(
cmc
[:
max_rank
])
num_valid_q
+=
1.
# compute average precision
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
num_rel
=
raw_cmc
.
sum
()
tmp_cmc
=
raw_cmc
.
cumsum
()
tmp_cmc
=
[
x
/
(
i
+
1.
)
for
i
,
x
in
enumerate
(
tmp_cmc
)]
tmp_cmc
=
np
.
asarray
(
tmp_cmc
)
*
raw_cmc
AP
=
tmp_cmc
.
sum
()
/
num_rel
all_AP
.
append
(
AP
)
assert
num_valid_q
>
0
,
'Error: all query identities do not appear in gallery'
all_cmc
=
np
.
asarray
(
all_cmc
).
astype
(
np
.
float32
)
all_cmc
=
all_cmc
.
sum
(
0
)
/
num_valid_q
mAP
=
np
.
mean
(
all_AP
)
mINP
=
np
.
mean
(
all_INP
)
logger
.
info
(
"[Eval][Epoch {}]: mAP: {:.5f}, mINP: {:.5f},rank_1: {:.5f}, rank_5: {:.5f}"
.
format
(
epoch_id
,
mAP
,
mINP
,
all_cmc
[
0
],
all_cmc
[
4
]))
return
mAP
def
_cal_feature
(
self
,
name
=
'gallery'
):
all_feas
=
None
all_image_id
=
None
all_camera_id
=
None
if
name
==
'gallery'
:
dataloader
=
self
.
gallery_dataloader
elif
name
==
'query'
:
dataloader
=
self
.
query_dataloader
else
:
raise
RuntimeError
(
"Only support gallery or query dataset"
)
has_cam_id
=
False
for
idx
,
batch
in
enumerate
(
dataloader
(
)):
# load is very time-consuming
batch
=
[
paddle
.
to_tensor
(
x
)
for
x
in
batch
]
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
])
if
len
(
batch
)
==
3
:
has_cam_id
=
True
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
])
out
=
self
.
model
(
batch
[
0
],
batch
[
1
])
batch_feas
=
out
[
"features"
]
# do norm
if
self
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
):
feas_norm
=
paddle
.
sqrt
(
paddle
.
sum
(
paddle
.
square
(
batch_feas
),
axis
=
1
,
keepdim
=
True
))
batch_feas
=
paddle
.
divide
(
batch_feas
,
feas_norm
)
batch_feas
=
batch_feas
batch_image_labels
=
batch
[
1
]
if
has_cam_id
:
batch_camera_labels
=
batch
[
2
]
if
all_feas
is
None
:
all_feas
=
batch_feas
if
has_cam_id
:
all_camera_id
=
batch
[
2
]
all_image_id
=
batch
[
1
]
else
:
all_feas
=
paddle
.
concat
([
all_feas
,
batch_feas
])
all_image_id
=
paddle
.
concat
([
all_image_id
,
batch
[
1
]])
if
has_cam_id
:
all_camera_id
=
paddle
.
concat
([
all_camera_id
,
batch
[
2
]])
if
paddle
.
distributed
.
get_world_size
()
>
1
:
feat_list
=
[]
img_id_list
=
[]
cam_id_list
=
[]
paddle
.
distributed
.
all_gather
(
feat_list
,
all_feas
)
paddle
.
distributed
.
all_gather
(
img_id_list
,
all_image_id
)
all_feas
=
paddle
.
concat
(
feat_list
,
axis
=
0
)
all_image_id
=
paddle
.
concat
(
img_id_list
,
axis
=
0
)
if
has_cam_id
:
paddle
.
distributed
.
all_gather
(
cam_id_list
,
all_camera_id
)
all_camera_id
=
paddle
.
concat
(
cam_id_list
,
axis
=
0
)
logger
.
info
(
"Build {} done, all feat shape: {}, begin to eval.."
.
format
(
name
,
all_feas
.
shape
))
return
all_feas
,
all_image_id
,
all_camera_id
tools/train.py
浏览文件 @
ed098b3c
...
...
@@ -22,13 +22,9 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
from
ppcls.utils
import
config
from
ppcls.engine.trainer
import
Trainer
from
ppcls.engine.trainer_reid
import
TrainerReID
if
__name__
==
"__main__"
:
args
=
config
.
parse_args
()
config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
True
)
if
"Trainer"
in
config
:
trainer
=
eval
(
config
[
"Trainer"
][
"name"
])(
config
,
mode
=
"train"
)
else
:
trainer
=
Trainer
(
config
,
mode
=
"train"
)
trainer
=
Trainer
(
config
,
mode
=
"train"
)
trainer
.
train
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录