PaddlePaddle / PaddleClas

Commit 41e1a86c
Authored Apr 21, 2022 by HydrogenSulfate
add center loss
Parent: 05770197
Showing 10 changed files with 657 additions and 68 deletions (+657 −68):
- ppcls/arch/gears/fc.py (+16 −3)
- ppcls/arch/utils.py (+48 −2)
- ppcls/configs/PersonReID/ResNet50_strong_baseline_market1501.yaml (+178 −0)
- ppcls/engine/engine.py (+12 −0)
- ppcls/engine/evaluation/retrieval.py (+204 −25)
- ppcls/engine/train/train.py (+18 −0)
- ppcls/loss/__init__.py (+1 −1)
- ppcls/loss/centerloss.py (+56 −36)
- ppcls/loss/triplet.py (+120 −0)
- ppcls/optimizer/__init__.py (+4 −1)
ppcls/arch/gears/fc.py

The FC head now accepts `**kwargs` so that `weight_attr` and `bias_attr` can be configured from YAML via `get_param_attr_dict`:

```diff
@@ -19,16 +19,29 @@ from __future__ import print_function

 import paddle
 import paddle.nn as nn
+
+from ppcls.arch.utils import get_param_attr_dict


 class FC(nn.Layer):
-    def __init__(self, embedding_size, class_num):
+    def __init__(self, embedding_size, class_num, **kwargs):
         super(FC, self).__init__()
         self.embedding_size = embedding_size
         self.class_num = class_num
         weight_attr = paddle.ParamAttr(
             initializer=paddle.nn.initializer.XavierNormal())
-        self.fc = paddle.nn.Linear(
-            self.embedding_size, self.class_num, weight_attr=weight_attr)
+        if 'weight_attr' in kwargs:
+            weight_attr = get_param_attr_dict(kwargs['weight_attr'])
+
+        bias_attr = None
+        if 'bias_attr' in kwargs:
+            bias_attr = get_param_attr_dict(kwargs['bias_attr'])
+
+        self.fc = nn.Linear(
+            self.embedding_size,
+            self.class_num,
+            weight_attr=weight_attr,
+            bias_attr=bias_attr)

     def forward(self, input, label=None):
         out = self.fc(input)
```
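For illustration, a minimal sketch of how the new kwargs are used, assuming the repo is on `PYTHONPATH` and that `FC.forward` simply returns `self.fc(input)` as the truncated diff suggests; the values mirror the `Arch.Head` section of the config added below:

```python
import paddle
from ppcls.arch.gears.fc import FC

# Normal(std=0.001) weight initializer and no bias, both parsed
# by get_param_attr_dict from plain dict/bool values.
head = FC(embedding_size=2048,
          class_num=751,
          weight_attr={'initializer': {'name': 'Normal', 'std': 0.001}},
          bias_attr=False)

feat = paddle.randn([64, 2048])  # a batch of BNNeck features
logits = head(feat)              # expected shape: [64, 751]
```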
ppcls/arch/utils.py

```diff
@@ -14,9 +14,11 @@

 import six
 import types
+import paddle
 from difflib import SequenceMatcher

 from . import backbone
+from typing import Any, Dict, Union


 def get_architectures():
@@ -31,8 +33,8 @@ def get_architectures():

 def get_blacklist_model_in_static_mode():
-    from ppcls.arch.backbone import distilled_vision_transformer
-    from ppcls.arch.backbone import vision_transformer
+    from ppcls.arch.backbone import (distilled_vision_transformer,
+                                     vision_transformer)
     blacklist = distilled_vision_transformer.__all__ + vision_transformer.__all__
     return blacklist
```

The third hunk (@@ -51,3 +53,47 @@) appends `get_param_attr_dict` after `similar_architectures`, turning a config dict into a `paddle.ParamAttr`:

```python
    scores.sort(key=lambda x: x[1], reverse=True)
    similar_names = [names[s[0]] for s in scores[:min(topk, len(scores))]]
    return similar_names


def get_param_attr_dict(ParamAttr_config: Union[None, bool, Dict[str, Dict]]
                        ) -> Union[None, bool, paddle.ParamAttr]:
    """parse ParamAttr from a dict

    Args:
        ParamAttr_config (Union[None, bool, Dict[str, Dict]]): ParamAttr configuration

    Returns:
        Union[None, bool, paddle.ParamAttr]: generated ParamAttr
    """
    if ParamAttr_config is None:
        return None
    if isinstance(ParamAttr_config, bool):
        return ParamAttr_config
    ParamAttr_dict = {}
    if 'initializer' in ParamAttr_config:
        initializer_cfg = ParamAttr_config.get('initializer')
        if 'name' in initializer_cfg:
            initializer_name = initializer_cfg.pop('name')
            ParamAttr_dict['initializer'] = getattr(
                paddle.nn.initializer, initializer_name)(**initializer_cfg)
        else:
            raise ValueError("'name' must be specified in initializer_cfg")
    if 'learning_rate' in ParamAttr_config:
        # NOTE: only support a single value now
        learning_rate_value = ParamAttr_config.get('learning_rate')
        if isinstance(learning_rate_value, (int, float)):
            ParamAttr_dict['learning_rate'] = learning_rate_value
        else:
            raise ValueError(
                f"learning_rate_value must be float or int, but got {type(learning_rate_value)}"
            )
    if 'regularizer' in ParamAttr_config:
        regularizer_cfg = ParamAttr_config.get('regularizer')
        if 'name' in regularizer_cfg:
            # L1Decay or L2Decay
            regularizer_name = regularizer_cfg.pop('name')
            ParamAttr_dict['regularizer'] = getattr(
                paddle.regularizer, regularizer_name)(**regularizer_cfg)
        else:
            raise ValueError("'name' must be specified in regularizer_cfg")
    return paddle.ParamAttr(**ParamAttr_dict)
```
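A quick, hypothetical usage sketch (the config fragment is illustrative, and note that the function pops `name` out of the nested dicts in place):

```python
import paddle
from ppcls.arch.utils import get_param_attr_dict

# A weight_attr fragment as it could appear in a YAML config.
weight_attr_cfg = {
    'initializer': {'name': 'Normal', 'std': 0.001},
    'learning_rate': 0.1,
    'regularizer': {'name': 'L2Decay', 'coeff': 0.0005},
}

attr = get_param_attr_dict(weight_attr_cfg)  # -> paddle.ParamAttr
layer = paddle.nn.Linear(2048, 751, weight_attr=attr)
```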
ppcls/configs/PersonReID/ResNet50_strong_baseline_market1501.yaml (new file, mode 100644)

```yaml
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  save_interval: 40
  eval_during_train: True
  eval_interval: 10
  epochs: 120
  print_batch_step: 20
  use_visualdl: False
  warmup_epoch_by_epoch: True
  eval_mode: "retrieval"
  re_ranking: True
  # used for static mode and model export
  image_shape: [3, 256, 128]
  save_inference_dir: "./inference"

# model architecture
Arch:
  name: "RecModel"
  infer_output_key: "features"
  infer_add_softmax: False
  Backbone:
    name: "ResNet50_last_stage_stride1"
    pretrained: True
    stem_act: null
  BackboneStopLayer:
    name: "flatten"
  Neck:
    name: BNNeck
    num_features: &feat_dim 2048
  Head:
    name: "FC"
    embedding_size: *feat_dim
    class_num: &class_num 751
    weight_attr:
      initializer:
        name: Normal
        std: 0.001
    bias_attr: False

# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
    - TripletLossV3:
        weight: 1.0
        margin: 0.3
        normalize_feature: false
    - CenterLoss:
        weight: 0.0005
        num_classes: *class_num
        feat_dim: *feat_dim
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  - Adam:
      scope: model
      lr:
        name: Piecewise
        decay_epochs: [30, 60]
        values: [0.00035, 0.000035, 0.0000035]
        warmup_epoch: 10
        warmup_start_lr: 0.0000035
        warmup_epoch_by_epoch: True
      regularizer:
        name: 'L2'
        coeff: 0.0005
  - SGD:
      scope: TripletLossV3
      lr:
        name: Constant
        learning_rate: 0.5

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: "VeriWild"
      image_root: "./dataset/market1501/bounding_box_train"
      cls_label_path: "./dataset/market1501/bounding_box_train.txt"
      relabel: True
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: [128, 256]
        - RandFlipImage:
            flip_code: 1
        - Pad:
            padding: 10
        - RandCropImage:
            size: [128, 256]
            scale: [0.8022, 0.8022]
            ratio: [0.5, 0.5]
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - RandomErasing:
            EPSILON: 0.5
            sl: 0.02
            sh: 0.4
            r1: 0.3
            mean: [0.4914, 0.4822, 0.4465]
    sampler:
      name: DistributedRandomIdentitySampler
      batch_size: 64
      num_instances: 4
      drop_last: True
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: "VeriWild"
        image_root: "./dataset/market1501/query"
        cls_label_path: "./dataset/market1501/query.txt"
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: [128, 256]
          - NormalizeImage:
              scale: 0.00392157
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        name: DistributedBatchSampler
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: "VeriWild"
        image_root: "./dataset/market1501/bounding_box_test"
        cls_label_path: "./dataset/market1501/bounding_box_test.txt"
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: [128, 256]
          - NormalizeImage:
              scale: 0.00392157
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        name: DistributedBatchSampler
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - Recallk:
        topk: [1, 5]
    - mAP: {}
```
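How the three `Loss.Train` entries combine is worth spelling out: the loss builder sums weight × loss over the list, so CenterLoss's tiny weight (0.0005) offsets its large raw magnitude. A minimal standalone sketch of that contract, with made-up loss magnitudes (this is illustrative, not the actual ppcls `CombinedLoss`):

```python
import paddle

def combine(loss_terms):
    """loss_terms: list of (weight, {name: scalar_loss}) pairs."""
    total = paddle.zeros([1])
    for weight, loss_dict in loss_terms:
        for value in loss_dict.values():
            total = total + weight * value
    return total

# Plausible magnitudes early in training (made-up numbers):
terms = [
    (1.0, {'CELoss': paddle.to_tensor([2.3])}),
    (1.0, {'TripletLossV3': paddle.to_tensor([0.4])}),
    (0.0005, {'CenterLoss': paddle.to_tensor([900.0])}),
]
print(combine(terms))  # 2.3 + 0.4 + 0.0005*900 ≈ [3.15]
```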
ppcls/engine/engine.py

When `Global.warmup_epoch_by_epoch` is set, every LR scheduler is stepped once per epoch, and once before the first epoch:

```diff
@@ -298,12 +298,24 @@ class Engine(object):
         self.max_iter = len(self.train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(self.train_dataloader)
+
+        if self.config["Global"].get("warmup_epoch_by_epoch", False):
+            for i in range(len(self.lr_sch)):
+                self.lr_sch[i].step()
+            logger.info(
+                "lr_sch step once before first epoch, when Global.warmup_epoch_by_epoch=True"
+            )
+
         for epoch_id in range(best_metric["epoch"] + 1,
                               self.config["Global"]["epochs"] + 1):
             acc = 0.0
             # for one epoch train
             self.train_epoch_func(self, epoch_id, print_batch_step)
+            if self.config["Global"].get("warmup_epoch_by_epoch", False):
+                for i in range(len(self.lr_sch)):
+                    self.lr_sch[i].step()
             if self.use_dali:
                 self.train_dataloader.reset()
             metric_msg = ", ".join([
```
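The effect of epoch-wise warmup, sketched with paddle's stock schedulers; the numbers mirror the `Optimizer.lr` section of the YAML above, and stepping once up front matches what `engine.py` now does before the first epoch:

```python
import paddle

# Linear warmup from 3.5e-6 to 3.5e-4 over 10 epochs, then piecewise
# decay at epochs 30 and 60.
decay = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[30, 60], values=[0.00035, 0.000035, 0.0000035])
sched = paddle.optimizer.lr.LinearWarmup(
    learning_rate=decay, warmup_steps=10,
    start_lr=0.0000035, end_lr=0.00035)

sched.step()  # step once before the first epoch (warmup_epoch_by_epoch=True)
for epoch_id in range(1, 13):
    print(epoch_id, sched.get_lr())
    sched.step()  # stepped once per epoch after train_epoch_func
```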
ppcls/engine/evaluation/retrieval.py

```diff
@@ -16,6 +16,8 @@ from __future__ import division
 from __future__ import print_function

+import platform
+import numpy as np
 import paddle

 from ppcls.utils import logger
```
@@ -49,34 +51,55 @@ def retrieval_eval(engine, epoch_id=0)

The per-block metric loop is now gated on `Global.re_ranking`; with re-ranking enabled, the full query-gallery distance matrix is re-ranked and scored with the market1501 protocol:

```python
        metric_dict = {metric_key: 0.}
    else:
        metric_dict = dict()
        reranking_flag = engine.config['Global'].get('re_ranking', False)
        logger.info(f"re_ranking={reranking_flag}")
        if not reranking_flag:
            for block_idx, block_fea in enumerate(fea_blocks):
                similarity_matrix = paddle.matmul(
                    block_fea, gallery_feas, transpose_y=True)
                if query_query_id is not None:
                    query_id_block = query_id_blocks[block_idx]
                    query_id_mask = (query_id_block != gallery_unique_id.t())

                    image_id_block = image_id_blocks[block_idx]
                    image_id_mask = (image_id_block != gallery_img_id.t())

                    keep_mask = paddle.logical_or(query_id_mask,
                                                  image_id_mask)
                    similarity_matrix = similarity_matrix * keep_mask.astype(
                        "float32")
                else:
                    keep_mask = None

                metric_tmp = engine.eval_metric_func(
                    similarity_matrix, image_id_blocks[block_idx],
                    gallery_img_id, keep_mask)

                for key in metric_tmp:
                    if key not in metric_dict:
                        metric_dict[key] = metric_tmp[
                            key] * block_fea.shape[0] / len(query_feas)
                    else:
                        metric_dict[key] += metric_tmp[
                            key] * block_fea.shape[0] / len(query_feas)
        else:
            distmat = re_ranking(
                query_feas, gallery_feas, k1=20, k2=6, lambda_value=0.3)
            cmc, mAP = eval_func(distmat,
                                 np.squeeze(query_img_id.numpy()),
                                 np.squeeze(gallery_img_id.numpy()),
                                 np.squeeze(query_query_id.numpy()),
                                 np.squeeze(gallery_unique_id.numpy()))
            metric_dict["recall1"] = paddle.to_tensor(cmc[0])
            metric_dict["recall5"] = paddle.to_tensor(cmc[4])
            metric_dict["mAP"] = paddle.to_tensor(mAP)

    metric_info_list = []
    for key in metric_dict:
        if metric_key is None:
            ...
```
@@ -88,6 +111,162 @@ def retrieval_eval(engine, epoch_id=0)

Two helper functions are appended after `return metric_dict[metric_key]`. The first implements k-reciprocal re-ranking:

```python
def re_ranking(queFea,
               galFea,
               k1=20,
               k2=6,
               lambda_value=0.5,
               local_distmat=None,
               only_local=False):
    # if the feature vectors are numpy arrays, convert them to paddle
    # Tensors (paddle.to_tensor) before calling this function
    query_num = queFea.shape[0]
    all_num = query_num + galFea.shape[0]
    if only_local:
        original_dist = local_distmat
    else:
        feat = paddle.concat([queFea, galFea])
        logger.info('using GPU to compute original distance')

        # L2 distance: ||a||^2 + ||b||^2 - 2ab
        distmat = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([all_num, all_num]) + \
            paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([all_num, all_num]).t()
        distmat = distmat.addmm(x=feat, y=feat.t(), alpha=-2.0, beta=1.0)

        # Cosine distance
        # distmat = paddle.matmul(queFea, galFea, transpose_y=True)
        # if query_query_id is not None:
        #     query_id_mask = (queCid != galCid.t())
        #     image_id_mask = (queId != galId.t())
        #     keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
        #     distmat = distmat * keep_mask.astype("float32")

        original_dist = distmat.cpu().numpy()
        del feat
        if local_distmat is not None:
            original_dist = original_dist + local_distmat

    gallery_num = original_dist.shape[0]
    original_dist = np.transpose(original_dist /
                                 np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)
    initial_rank = np.argsort(original_dist).astype(np.int32)

    logger.info('starting re_ranking')
    for i in range(all_num):
        # k-reciprocal neighbors
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        for j in range(len(k_reciprocal_index)):
            candidate = k_reciprocal_index[j]
            candidate_forward_k_neigh_index = initial_rank[
                candidate, :int(np.around(k1 / 2)) + 1]
            candidate_backward_k_neigh_index = initial_rank[
                candidate_forward_k_neigh_index, :int(np.around(k1 / 2)) + 1]
            fi_candidate = np.where(
                candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[
                fi_candidate]
            if len(
                    np.intersect1d(candidate_k_reciprocal_index,
                                   k_reciprocal_index)
            ) > 2 / 3 * len(candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(
                    k_reciprocal_expansion_index, candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)

    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        V_qe = np.zeros_like(V, dtype=np.float16)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank

    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:, i] != 0)[0])

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
    for i in range(query_num):
        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(
                V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    final_dist = jaccard_dist * (1 - lambda_value
                                 ) + original_dist * lambda_value
    del original_dist
    del V
    del jaccard_dist
    final_dist = final_dist[:query_num, query_num:]
    return final_dist
```
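A toy invocation, assuming this module is importable; as the comment at the top of the function notes, the features must already be paddle Tensors:

```python
import numpy as np
import paddle
from ppcls.engine.evaluation.retrieval import re_ranking

np.random.seed(0)
query_feas = paddle.to_tensor(np.random.randn(8, 128).astype("float32"))
gallery_feas = paddle.to_tensor(np.random.randn(32, 128).astype("float32"))

# Returns an (8, 32) query-to-gallery distance matrix blending the Jaccard
# distance over k-reciprocal neighbor sets with the raw L2 distance.
dist = re_ranking(query_feas, gallery_feas, k1=5, k2=3, lambda_value=0.3)
print(dist.shape)  # (8, 32)
```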
The second helper scores a distance matrix with the market1501 protocol:

```python
def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    """Evaluation with market1501 metric
    Key: for each query identity, its gallery images from the same camera view are discarded.
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".
              format(num_g))
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid query
    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # compute cmc curve
        # binary vector, positions with value 1 are correct matches
        orig_cmc = matches[q_idx][keep]
        if not np.any(orig_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        cmc = orig_cmc.cumsum()
        cmc[cmc > 1] = 1

        all_cmc.append(cmc[:max_rank])
        num_valid_q += 1.

        # compute average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = orig_cmc.sum()
        tmp_cmc = orig_cmc.cumsum()
        tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
        tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
        AP = tmp_cmc.sum() / num_rel
        all_AP.append(AP)

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP
```
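A worked toy example of the protocol with hand-checkable numbers:

```python
import numpy as np
from ppcls.engine.evaluation.retrieval import eval_func

# 2 queries, 4 gallery images; row i holds query i's distances.
distmat = np.array([
    [0.1, 0.9, 0.8, 0.7],  # query 0: nearest gallery image is index 0
    [0.6, 0.2, 0.9, 0.3],  # query 1: nearest gallery image is index 1
])
q_pids = np.array([0, 1])
g_pids = np.array([0, 1, 0, 1])
q_camids = np.array([0, 0])
g_camids = np.array([1, 1, 1, 1])  # different camera, so nothing is discarded

cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=4)
print(cmc[0])  # rank-1 accuracy: 1.0 (both queries hit at rank 1)
print(mAP)     # (0.8333 + 1.0) / 2 ≈ 0.9167
```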
The hunk ends in existing context at the top of `cal_feature`:

```python
def cal_feature(engine, name='gallery'):
    all_feas = None
    all_image_id = None
    ...
```
ppcls/engine/train/train.py

@@ -63,9 +63,27 @@ def train_epoch(engine, epoch_id, print_batch_step)

After backward, the BNNeck BN bias gradient is zeroed (keeping that bias frozen, per the strong-baseline trick), and the gradients of loss-owned parameters are rescaled by 1/weight before their optimizer steps:

```python
        loss_dict["loss"].backward()

        # keep the BNNeck's BN bias (beta) frozen by zeroing its gradient
        if hasattr(engine.model.neck, 'bn'):
            engine.model.neck.bn.bias.grad.set_value(
                paddle.zeros_like(engine.model.neck.bn.bias.grad))

        for i in range(len(engine.optimizer)):
            # manually scale up grad of center_loss
            if i == 1:
                for j in range(len(engine.train_loss_func.loss_func)):
                    if len(engine.train_loss_func.loss_func[j]
                           .parameters()) == 0:
                        continue
                    for param in engine.train_loss_func.loss_func[
                            j].parameters():
                        if hasattr(param, 'grad') and param.grad is not None:
                            param.grad.set_value(param.grad * (
                                1.0 / engine.train_loss_func.loss_weight[j]))
            engine.optimizer[i].step()

        # clear grad
        for i in range(len(engine.optimizer)):
            engine.optimizer[i].clear_grad()

        # step lr
        for i in range(len(engine.lr_sch)):
            engine.lr_sch[i].step()
```
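Why divide by the loss weight: the combined loss backpropagates weight × CenterLoss, so the centers' gradients arrive scaled by 0.0005; multiplying by 1/weight restores the gradient of the raw loss before the dedicated SGD step. A small self-contained check:

```python
import paddle

weight = 0.0005
param = paddle.to_tensor([3.0], stop_gradient=False)

# weighted loss, so d(loss)/d(param) = weight * 2 * param
loss = weight * (param ** 2).sum()
loss.backward()
print(param.grad)  # [0.003] == 0.0005 * 6.0

# the rescaling applied in train_epoch for the loss-owned parameters
param.grad.set_value(param.grad * (1.0 / weight))
print(param.grad)  # [6.0], the gradient of the unweighted loss
```

This mirrors the reference strong-baseline implementation, which applies the same 1/weight rescale to the center parameters before their own optimizer step.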
ppcls/loss/__init__.py

```diff
@@ -11,7 +11,7 @@ from .emlloss import EmlLoss
 from .msmloss import MSMLoss
 from .npairsloss import NpairsLoss
 from .trihardloss import TriHardLoss
-from .triplet import TripletLoss, TripletLossV2
+from .triplet import TripletLoss, TripletLossV2, TripletLossV3
 from .supconloss import SupConLoss
 from .pairwisecosface import PairwiseCosface
 from .dmlloss import DMLLoss
```
ppcls/loss/centerloss.py

The loss is rewritten: the random centers become a learnable parameter, features are read from the "backbone" key of the model output, and the distance computation is vectorized:

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import Dict

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class CenterLoss(nn.Layer):
    """Center loss

    Args:
        num_classes (int): number of classes.
        feat_dim (int): number of feature dimensions.
    """

    def __init__(self, num_classes: int, feat_dim: int):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim

        # random initial centers, stored as a learnable parameter
        random_init_centers = paddle.randn(
            shape=[self.num_classes, self.feat_dim])
        self.centers = self.create_parameter(
            shape=(self.num_classes, self.feat_dim),
            default_initializer=nn.initializer.Assign(random_init_centers))
        self.add_parameter("centers", self.centers)

    def __call__(self, input: Dict[str, paddle.Tensor],
                 target: paddle.Tensor) -> Dict[str, paddle.Tensor]:
        """compute center loss.

        Args:
            input (Dict[str, paddle.Tensor]): {'backbone': (batch_size, feature_dim), ...}.
            target (paddle.Tensor): ground truth label with shape (batch_size, ).

        Returns:
            Dict[str, paddle.Tensor]: {'CenterLoss': loss}.
        """
        feats = input['backbone']
        labels = target

        # squeeze labels to shape (batch_size, )
        if labels.ndim >= 2 and labels.shape[-1] == 1:
            labels = paddle.squeeze(labels, axis=[-1])

        batch_size = feats.shape[0]
        # squared L2 distance between every feature and every center:
        # ||x||^2 + ||c||^2 - 2 * x . c
        distmat = paddle.pow(feats, 2).sum(axis=1, keepdim=True).expand([batch_size, self.num_classes]) + \
            paddle.pow(self.centers, 2).sum(axis=1, keepdim=True).expand([self.num_classes, batch_size]).t()
        distmat = distmat.addmm(x=feats, y=self.centers.t(), beta=1, alpha=-2)

        # keep only each sample's own class column
        classes = paddle.arange(self.num_classes).astype(labels.dtype)
        labels = labels.unsqueeze(1).expand([batch_size, self.num_classes])
        mask = labels.equal(classes.expand([batch_size, self.num_classes]))

        dist = distmat * mask.astype(feats.dtype)
        loss = dist.clip(min=1e-12, max=1e+12).sum() / batch_size
        return {'CenterLoss': loss}
```
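A quick usage sketch; batch and class counts mirror the Market1501 config above:

```python
import paddle
from ppcls.loss.centerloss import CenterLoss

loss_fn = CenterLoss(num_classes=751, feat_dim=2048)

out = {"backbone": paddle.randn([64, 2048])}  # model output dict
labels = paddle.randint(0, 751, [64, 1])      # (batch_size, 1) labels get squeezed
print(loss_fn(out, labels))                   # {'CenterLoss': Tensor(...)}
```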
ppcls/loss/triplet.py

```diff
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from typing import Tuple
+
 import paddle
 import paddle.nn as nn
```

@@ -135,3 +136,122 @@ TripletLossV3 is appended after the existing TripletLoss:

```python
        y = paddle.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return {"TripletLoss": loss}


class TripletLossV3(nn.Layer):
    """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in the paper 'In Defense of the
    Triplet Loss for Person Re-Identification'."""

    def __init__(self, margin=None, normalize_feature=False):
        super(TripletLossV3, self).__init__()
        self.normalize_feature = normalize_feature
        self.margin = margin
        if margin is not None:
            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        else:
            self.ranking_loss = nn.SoftMarginLoss()

    def forward(self, input, target):
        global_feat = input["backbone"]
        if self.normalize_feature:
            global_feat = self._normalize(global_feat, axis=-1)
        dist_mat = self._euclidean_dist(global_feat, global_feat)
        dist_ap, dist_an = self._hard_example_mining(dist_mat, target)

        y = paddle.ones_like(dist_an)
        if self.margin is not None:
            loss = self.ranking_loss(dist_an, dist_ap, y)
        else:
            loss = self.ranking_loss(dist_an - dist_ap, y)
        return {"TripletLossV3": loss}

    def _normalize(self, x: paddle.Tensor, axis: int=-1) -> paddle.Tensor:
        """Normalize to unit length along the specified dimension.

        Args:
            x (paddle.Tensor): (batch_size, feature_dim)
            axis (int, optional): normalization dim. Defaults to -1.

        Returns:
            paddle.Tensor: (batch_size, feature_dim)
        """
        x = 1. * x / (paddle.norm(
            x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
        return x

    def _euclidean_dist(self, x: paddle.Tensor,
                        y: paddle.Tensor) -> paddle.Tensor:
        """compute the euclidean distance between two batches of vectors

        Args:
            x (paddle.Tensor): (N, feature_dim)
            y (paddle.Tensor): (M, feature_dim)

        Returns:
            paddle.Tensor: (N, M)
        """
        m, n = x.shape[0], y.shape[0]
        d = x.shape[1]

        xx = paddle.pow(x, 2).sum(1, keepdim=True).expand([m, n])
        yy = paddle.pow(y, 2).sum(1, keepdim=True).expand([n, m]).t()
        dist = xx + yy
        dist = dist.addmm(x, y.t(), alpha=-2, beta=1)
        # dist = dist - 2*(x@y.t())
        dist = dist.clip(min=1e-12).sqrt()  # for numerical stability
        return dist

    def _hard_example_mining(self,
                             dist_mat: paddle.Tensor,
                             labels: paddle.Tensor,
                             return_inds: bool=False
                             ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """For each anchor, find the hardest positive and negative sample.

        Args:
            dist_mat (paddle.Tensor): pairwise distance between samples, [N, N]
            labels (paddle.Tensor): labels, [N, ]
            return_inds (bool, optional): whether to return the indices. Defaults to False.

        Returns:
            Tuple[paddle.Tensor, paddle.Tensor]: [(N, ), (N, )]

        NOTE: Only consider the case in which all labels have the same number
        of samples, so we can cope with all anchors in parallel.
        """
        assert len(dist_mat.shape) == 2
        assert dist_mat.shape[0] == dist_mat.shape[1]
        N = dist_mat.shape[0]

        # shape [N, N]
        is_pos = labels.expand([N, N]).equal(labels.expand([N, N]).t())
        is_neg = labels.expand([N, N]).not_equal(labels.expand([N, N]).t())

        # `dist_ap` means distance(anchor, positive), shape [N, 1]
        dist_ap = paddle.max(
            dist_mat[is_pos].reshape([N, -1]), 1, keepdim=True)
        # `dist_an` means distance(anchor, negative), shape [N, 1]
        dist_an = paddle.min(
            dist_mat[is_neg].reshape([N, -1]), 1, keepdim=True)

        # shape [N]
        dist_ap = dist_ap.squeeze(1)
        dist_an = dist_an.squeeze(1)

        if return_inds:
            # shape [N, N]
            ind = paddle.arange(0, N).unsqueeze(0).expand([N, N])
            # indices of the hardest positive/negative per anchor, shape [N, 1]
            relative_p_inds = paddle.argmax(
                dist_mat[is_pos].reshape([N, -1]), 1, keepdim=True)
            relative_n_inds = paddle.argmin(
                dist_mat[is_neg].reshape([N, -1]), 1, keepdim=True)
            p_inds = paddle.take_along_axis(
                ind[is_pos].reshape([N, -1]), relative_p_inds, 1)
            n_inds = paddle.take_along_axis(
                ind[is_neg].reshape([N, -1]), relative_n_inds, 1)
            # shape [N]
            p_inds = p_inds.squeeze(1)
            n_inds = n_inds.squeeze(1)
            return dist_ap, dist_an, p_inds, n_inds

        return dist_ap, dist_an
```
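A toy usage sketch with a PK-sampled batch (P identities × K instances), matching what DistributedRandomIdentitySampler produces in the config above:

```python
import paddle
from ppcls.loss.triplet import TripletLossV3

loss_fn = TripletLossV3(margin=0.3, normalize_feature=False)

feats = paddle.randn([16, 2048])  # N = P*K = 4*4
labels = paddle.to_tensor([0, 0, 0, 0, 1, 1, 1, 1,
                           2, 2, 2, 2, 3, 3, 3, 3])
out = loss_fn({"backbone": feats}, labels)
print(out)  # {'TripletLossV3': Tensor(...)}
```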
ppcls/optimizer/__init__.py

```diff
@@ -103,8 +103,11 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
         if optim_scope.endswith("Loss"):
             # optimizer for loss
             for m in model_list[i].sublayers(True):
-                if m.__class_name == optim_scope:
+                if m.__class__.__name__ == optim_scope:
                     optim_model.append(m)
         elif optim_scope == "model":
             # optimizer for the entire model
             optim_model.append(model_list[i])
         else:
             # optimizer for a module in the model, such as backbone, neck, head...
             if hasattr(model_list[i], optim_scope):
```
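A sketch of the fixed scope lookup: sublayers are matched to an optimizer scope by class name, which is how the second optimizer entry in the YAML picks out its loss layer (`DummyLoss` here is a stand-in, not a ppcls class):

```python
import paddle.nn as nn

class DummyLoss(nn.Layer):
    pass

model = nn.Sequential(nn.Linear(4, 4), DummyLoss())
optim_scope = "DummyLoss"  # e.g. a "...Loss" scope from the Optimizer config

# the fixed comparison: match sublayers by their class name
optim_model = [m for m in model.sublayers(True)
               if m.__class__.__name__ == optim_scope]
print(optim_model)  # [DummyLoss()]
```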