Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
ff232a58
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ff232a58
编写于
6月 03, 2021
作者:
B
Bin Lu
提交者:
GitHub
6月 03, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #779 from RainFrost1/develop_reg
Add TrainerReID
上级
6c4de88f
3ef99930
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
383 addition
and
9 deletion
+383
-9
ppcls/configs/Vehicle/ResNet50_ReID.yaml
ppcls/configs/Vehicle/ResNet50_ReID.yaml
+161
-0
ppcls/data/__init__.py
ppcls/data/__init__.py
+3
-4
ppcls/engine/trainer.py
ppcls/engine/trainer.py
+6
-4
ppcls/engine/trainer_reid.py
ppcls/engine/trainer_reid.py
+208
-0
tools/train.py
tools/train.py
+5
-1
未找到文件。
ppcls/configs/Vehicle/ResNet50_ReID.yaml
0 → 100644
浏览文件 @
ff232a58
# global configs
Trainer
:
name
:
TrainerReID
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
"
./output/"
device
:
"
gpu"
class_num
:
30671
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
160
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
"
./inference"
num_split
:
1
feature_normalize
:
True
# model architecture
Arch
:
name
:
"
RecModel"
Backbone
:
name
:
"
ResNet50"
Stoplayer
:
name
:
"
flatten_0"
output_dim
:
2048
embedding_size
:
512
Head
:
name
:
"
ArcMargin"
embedding_size
:
512
class_num
:
431
margin
:
0.15
scale
:
32
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
-
TripletLossV2
:
weight
:
1.0
margin
:
0.5
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
MultiStepDecay
learning_rate
:
0.01
milestones
:
[
30
,
60
,
70
,
80
,
90
,
100
,
120
,
140
]
gamma
:
0.5
verbose
:
False
last_epoch
:
-1
regularizer
:
name
:
'
L2'
coeff
:
0.0005
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
"
VeriWild"
image_root
:
"
/work/dataset/VeRI-Wild/images/"
cls_label_path
:
"
/work/dataset/VeRI-Wild/train_test_split/debug_train.txt"
transform_ops
:
-
ResizeImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
AugMix
:
prob
:
0.5
-
NormalizeImage
:
scale
:
0.00392157
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.5
sl
:
0.02
sh
:
0.4
r1
:
0.3
mean
:
[
0.
,
0.
,
0.
]
sampler
:
name
:
DistributedRandomIdentitySampler
batch_size
:
128
num_instances
:
2
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
6
use_shared_memory
:
False
Query
:
# TOTO: modify to the latest trainer
dataset
:
name
:
"
VeriWild"
image_root
:
"
/work/dataset/VeRI-Wild/images"
cls_label_path
:
"
/work/dataset/VeRI-Wild/train_test_split/debug_test_query.txt"
transform_ops
:
-
ResizeImage
:
size
:
224
-
NormalizeImage
:
scale
:
0.00392157
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
6
use_shared_memory
:
False
Gallery
:
# TOTO: modify to the latest trainer
dataset
:
name
:
"
VeriWild"
image_root
:
"
/work/dataset/VeRI-Wild/images"
cls_label_path
:
"
/work/dataset/VeRI-Wild/train_test_split/debug_test.txt"
transform_ops
:
-
ResizeImage
:
size
:
224
-
NormalizeImage
:
scale
:
0.00392157
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
6
use_shared_memory
:
False
Infer
:
infer_imgs
:
"
docs/images/whl/demo.jpg"
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
ppcls/data/__init__.py
浏览文件 @
ff232a58
...
...
@@ -27,14 +27,13 @@ from ppcls.data.dataloader.common_dataset import create_operators
from
ppcls.data.dataloader.vehicle_dataset
import
CompCars
,
VeriWild
# sampler
from
ppcls.data.dataloader
import
DistributedRandomIdentitySampler
from
ppcls.data.dataloader.DistributedRandomIdentitySampler
import
DistributedRandomIdentitySampler
from
ppcls.data.preprocess
import
transform
def
build_dataloader
(
config
,
mode
,
device
,
seed
=
None
):
assert
mode
in
[
'Train'
,
'Eval'
,
'Test'
],
"Mode should be Train, Eval
or Test.
"
assert
mode
in
[
'Train'
,
'Eval'
,
'Test'
,
'Gallery'
,
'Query'
],
"Mode should be Train, Eval
, Test, Gallery or Query
"
# build dataset
config_dataset
=
config
[
mode
][
'dataset'
]
config_dataset
=
copy
.
deepcopy
(
config_dataset
)
...
...
ppcls/engine/trainer.py
浏览文件 @
ff232a58
# Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -109,8 +109,10 @@ class Trainer(object):
def
train
(
self
):
# build train loss and metric info
loss_func
=
self
.
_build_loss_info
(
self
.
config
[
"Loss"
])
metric_func
=
self
.
_build_metric_info
(
self
.
config
[
"Metric"
])
if
"Metric"
in
self
.
config
:
metric_func
=
self
.
_build_metric_info
(
self
.
config
[
"Metric"
])
else
:
metric_func
=
None
train_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Train"
,
self
.
device
)
...
...
@@ -156,7 +158,7 @@ class Trainer(object):
else
:
out
=
self
.
model
(
batch
[
0
],
batch
[
1
])
# calc loss
loss_dict
=
loss_func
(
out
,
batch
[
-
1
])
loss_dict
=
loss_func
(
out
,
batch
[
1
])
for
key
in
loss_dict
:
if
not
key
in
output_info
:
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
...
...
ppcls/engine/trainer_reid.py
0 → 100644
浏览文件 @
ff232a58
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../../'
)))
import
numpy
as
np
import
paddle
from
.trainer
import
Trainer
from
ppcls.utils
import
logger
from
ppcls.data
import
build_dataloader
class
TrainerReID
(
Trainer
):
def
__init__
(
self
,
config
,
mode
=
"train"
):
super
().
__init__
(
config
,
mode
)
self
.
gallery_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Gallery"
,
self
.
device
)
self
.
query_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Query"
,
self
.
device
)
@
paddle
.
no_grad
()
def
eval
(
self
,
epoch_id
=
0
):
output_info
=
dict
()
self
.
model
.
eval
()
print_batch_step
=
self
.
config
[
"Global"
][
"print_batch_step"
]
# step1. build gallery
gallery_feas
,
gallery_img_id
,
gallery_camera_id
=
self
.
_cal_feature
(
name
=
'gallery'
)
query_feas
,
query_img_id
,
query_camera_id
=
self
.
_cal_feature
(
name
=
'query'
)
# step2. do evaluation
if
"num_split"
in
self
.
config
[
"Global"
]:
num_split
=
self
.
config
[
"Global"
][
"num_split"
]
else
:
num_split
=
1
fea_blocks
=
paddle
.
split
(
query_feas
,
num_or_sections
=
1
)
total_similarities_matrix
=
None
for
block_fea
in
fea_blocks
:
similarities_matrix
=
paddle
.
matmul
(
block_fea
,
gallery_feas
,
transpose_y
=
True
)
if
total_similarities_matrix
is
None
:
total_similarities_matrix
=
similarities_matrix
else
:
total_similarities_matrix
=
paddle
.
concat
(
[
total_similarities_matrix
,
similarities_matrix
])
# distmat = (1 - total_similarities_matrix).numpy()
q_pids
=
query_img_id
.
numpy
().
reshape
((
query_img_id
.
shape
[
0
]))
g_pids
=
gallery_img_id
.
numpy
().
reshape
((
gallery_img_id
.
shape
[
0
]))
if
query_camera_id
is
not
None
and
gallery_camera_id
is
not
None
:
q_camids
=
query_camera_id
.
numpy
().
reshape
(
(
query_camera_id
.
shape
[
0
]))
g_camids
=
gallery_camera_id
.
numpy
().
reshape
(
(
gallery_camera_id
.
shape
[
0
]))
max_rank
=
50
num_q
,
num_g
=
total_similarities_matrix
.
shape
if
num_g
<
max_rank
:
max_rank
=
num_g
print
(
'Note: number of gallery samples is quite small, got {}'
.
format
(
num_g
))
# indices = np.argsort(distmat, axis=1)
indices
=
paddle
.
argsort
(
total_similarities_matrix
,
axis
=
1
,
descending
=
True
).
numpy
()
matches
=
(
g_pids
[
indices
]
==
q_pids
[:,
np
.
newaxis
]).
astype
(
np
.
int32
)
# compute cmc curve for each query
all_cmc
=
[]
all_AP
=
[]
all_INP
=
[]
num_valid_q
=
0.
# number of valid query
for
q_idx
in
range
(
num_q
):
# get query pid and camid
q_pid
=
q_pids
[
q_idx
]
q_camid
=
q_camids
[
q_idx
]
# remove gallery samples that have the same pid and camid with query
order
=
indices
[
q_idx
]
if
query_camera_id
is
not
None
and
gallery_camera_id
is
not
None
:
remove
=
(
g_pids
[
order
]
==
q_pid
)
&
(
g_camids
[
order
]
==
q_camid
)
else
:
remove
=
g_pids
[
order
]
==
q_pid
keep
=
np
.
invert
(
remove
)
# compute cmc curve
raw_cmc
=
matches
[
q_idx
][
keep
]
# binary vector, positions with value 1 are correct matches
if
not
np
.
any
(
raw_cmc
):
# this condition is true when query identity does not appear in gallery
continue
cmc
=
raw_cmc
.
cumsum
()
pos_idx
=
np
.
where
(
raw_cmc
==
1
)
max_pos_idx
=
np
.
max
(
pos_idx
)
inp
=
cmc
[
max_pos_idx
]
/
(
max_pos_idx
+
1.0
)
all_INP
.
append
(
inp
)
cmc
[
cmc
>
1
]
=
1
all_cmc
.
append
(
cmc
[:
max_rank
])
num_valid_q
+=
1.
# compute average precision
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
num_rel
=
raw_cmc
.
sum
()
tmp_cmc
=
raw_cmc
.
cumsum
()
tmp_cmc
=
[
x
/
(
i
+
1.
)
for
i
,
x
in
enumerate
(
tmp_cmc
)]
tmp_cmc
=
np
.
asarray
(
tmp_cmc
)
*
raw_cmc
AP
=
tmp_cmc
.
sum
()
/
num_rel
all_AP
.
append
(
AP
)
assert
num_valid_q
>
0
,
'Error: all query identities do not appear in gallery'
all_cmc
=
np
.
asarray
(
all_cmc
).
astype
(
np
.
float32
)
all_cmc
=
all_cmc
.
sum
(
0
)
/
num_valid_q
mAP
=
np
.
mean
(
all_AP
)
mINP
=
np
.
mean
(
all_INP
)
logger
.
info
(
"[Eval][Epoch {}]: mAP: {:.5f}, mINP: {:.5f},rank_1: {:.5f}, rank_5: {:.5f}"
.
format
(
epoch_id
,
mAP
,
mINP
,
all_cmc
[
0
],
all_cmc
[
4
]))
return
mAP
def
_cal_feature
(
self
,
name
=
'gallery'
):
all_feas
=
None
all_image_id
=
None
all_camera_id
=
None
if
name
==
'gallery'
:
dataloader
=
self
.
gallery_dataloader
elif
name
==
'query'
:
dataloader
=
self
.
query_dataloader
else
:
raise
RuntimeError
(
"Only support gallery or query dataset"
)
has_cam_id
=
False
for
idx
,
batch
in
enumerate
(
dataloader
(
)):
# load is very time-consuming
batch
=
[
paddle
.
to_tensor
(
x
)
for
x
in
batch
]
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
])
if
len
(
batch
)
==
3
:
has_cam_id
=
True
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
])
out
=
self
.
model
(
batch
[
0
],
batch
[
1
])
batch_feas
=
out
[
"features"
]
# do norm
if
self
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
):
feas_norm
=
paddle
.
sqrt
(
paddle
.
sum
(
paddle
.
square
(
batch_feas
),
axis
=
1
,
keepdim
=
True
))
batch_feas
=
paddle
.
divide
(
batch_feas
,
feas_norm
)
batch_feas
=
batch_feas
batch_image_labels
=
batch
[
1
]
if
has_cam_id
:
batch_camera_labels
=
batch
[
2
]
if
all_feas
is
None
:
all_feas
=
batch_feas
if
has_cam_id
:
all_camera_id
=
batch
[
2
]
all_image_id
=
batch
[
1
]
else
:
all_feas
=
paddle
.
concat
([
all_feas
,
batch_feas
])
all_image_id
=
paddle
.
concat
([
all_image_id
,
batch
[
1
]])
if
has_cam_id
:
all_camera_id
=
paddle
.
concat
([
all_camera_id
,
batch
[
2
]])
if
paddle
.
distributed
.
get_world_size
()
>
1
:
feat_list
=
[]
img_id_list
=
[]
cam_id_list
=
[]
paddle
.
distributed
.
all_gather
(
feat_list
,
all_feas
)
paddle
.
distributed
.
all_gather
(
img_id_list
,
all_image_id
)
all_feas
=
paddle
.
concat
(
feat_list
,
axis
=
0
)
all_image_id
=
paddle
.
concat
(
img_id_list
,
axis
=
0
)
if
has_cam_id
:
paddle
.
distributed
.
all_gather
(
cam_id_list
,
all_camera_id
)
all_camera_id
=
paddle
.
concat
(
cam_id_list
,
axis
=
0
)
logger
.
info
(
"Build {} done, all feat shape: {}, begin to eval.."
.
format
(
name
,
all_feas
.
shape
))
return
all_feas
,
all_image_id
,
all_camera_id
tools/train.py
浏览文件 @
ff232a58
...
...
@@ -22,9 +22,13 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
from
ppcls.utils
import
config
from
ppcls.engine.trainer
import
Trainer
from
ppcls.engine.trainer_reid
import
TrainerReID
if
__name__
==
"__main__"
:
args
=
config
.
parse_args
()
config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
True
)
trainer
=
Trainer
(
config
,
mode
=
"train"
)
if
"Trainer"
in
config
:
trainer
=
eval
(
config
[
"Trainer"
][
"name"
])(
config
,
mode
=
"train"
)
else
:
trainer
=
Trainer
(
config
,
mode
=
"train"
)
trainer
.
train
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录