PaddlePaddle / PaddleClas
Commit af90cd7c
Authored on Apr 12, 2022 by HydrogenSulfate
update center loss config and related code
Parent: 9de22673
Showing 12 changed files with 410 additions and 67 deletions (+410 -67)
ppcls/arch/gears/bnneck.py                                     +2   -2
ppcls/arch/gears/fc.py                                         +6   -2
ppcls/configs/Pedestrian/strong_baseline_m1.yaml               +3   -1
ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml  +173   -0
ppcls/engine/engine.py                                        +11   -1
ppcls/engine/evaluation/retrieval.py                           +8   -1
ppcls/engine/train/train.py                                   +13   -5
ppcls/engine/train/utils.py                                   +23   -7
ppcls/loss/centerloss.py                                      +58  -21
ppcls/optimizer/__init__.py                                   +92  -24
ppcls/optimizer/learning_rate.py                              +17   -0
ppcls/utils/save_load.py                                       +4   -3
ppcls/arch/gears/bnneck.py
@@ -6,8 +6,8 @@ class BNNeck(paddle.nn.Layer):
         super(BNNeck, self).__init__()
         self.num_filters = num_filters
         self.bn = paddle.nn.BatchNorm1D(
             self.num_filters)
+        # TODO: freeze bn.bias
         # if not trainable:
         #     self.bn.bias.trainable = False
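As an aside, a minimal sketch (not part of the commit) of what the commented-out TODO above would do; in Paddle a parameter is frozen by clearing its trainable flag:

import paddle

# Stand-in for BNNeck's normalization layer; 2048 matches num_filters above.
bn = paddle.nn.BatchNorm1D(2048)
# The behaviour the "# if not trainable:" lines sketch: freeze the BN bias
# so gradients no longer update it.
bn.bias.trainable = False
print(bn.bias.trainable)  # False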
ppcls/arch/gears/fc.py
@@ -25,10 +25,14 @@ class FC(nn.Layer):
         super(FC, self).__init__()
         self.embedding_size = embedding_size
         self.class_num = class_num
-        weight_attr = paddle.ParamAttr(
-            initializer=paddle.nn.initializer.XavierNormal())
+        # TODO: hard code for initializer
+        weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.Normal(std=0.001))
         self.fc = paddle.nn.Linear(
             self.embedding_size, self.class_num,
             weight_attr=weight_attr, bias_attr=bias_attr)

     def forward(self, input, label=None):
         out = self.fc(input)
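For context, a minimal sketch (not from the commit) contrasting the old and new head weight initializers; the 2048/751 shapes follow the Market1501 configs below:

import paddle

# Old behaviour: Xavier-normal initialized classifier weights.
w_old = paddle.ParamAttr(initializer=paddle.nn.initializer.XavierNormal())
# New behaviour: small-std normal init, matching the change above.
w_new = paddle.ParamAttr(initializer=paddle.nn.initializer.Normal(std=0.001))

fc = paddle.nn.Linear(2048, 751, weight_attr=w_new, bias_attr=False)
print(float(fc.weight.std()))  # roughly 0.001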
ppcls/configs/Pedestrian/strong_baseline_m1.yaml
@@ -8,12 +8,13 @@ Global:
   eval_during_train: True
   eval_interval: 10
   epochs: 120
-  print_batch_step: 10
+  print_batch_step: 20
   use_visualdl: False
   # used for static mode and model export
   image_shape: [3, 256, 128]
   save_inference_dir: "./inference"
   eval_mode: "retrieval"
+  feat_from: "neck"  # 'backbone' or 'neck'

 # model architecture
 Arch:
@@ -29,6 +30,7 @@ Arch:
   Neck:
     name: BNNeck
     num_filters: 2048
+    # trainable: False # TODO: free bn.bias
   Head:
     name: "FC"
     embedding_size: 2048
ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml (new file, mode 100644)

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  save_interval: 10
  eval_during_train: True
  eval_interval: 10
  epochs: 120
  print_batch_step: 20
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 256, 128]
  save_inference_dir: "./inference"
  eval_mode: "retrieval"
  feat_from: "neck"  # 'backbone' or 'neck'

# model architecture
Arch:
  name: "RecModel"
  infer_output_key: "features"
  infer_add_softmax: False
  Backbone:
    name: "ResNet50_last_stage_stride1"
    pretrained: True
    stem_act: null
  BackboneStopLayer:
    name: "flatten"
  Neck:
    name: BNNeck
    num_filters: 2048
    # trainable: False # TODO: free bn.bias
  Head:
    name: "FC"
    embedding_size: 2048
    class_num: 751
    bias_attr: false

# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
    - TripletLossV2:
        weight: 1.0
        margin: 0.3
        normalize_feature: false
    - CenterLoss:
        weight: 0.0005
        num_classes: 751
        feat_dim: 2048
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  model:
    name: Adam
    lr:
      name: Piecewise
      decay_epochs: [30, 60]
      values: [0.00035, 0.000035, 0.0000035]
      warmup_epoch: 10
      warmup_start_lr: 0.0000035
    regularizer:
      name: 'L2'
      coeff: 0.0005
  loss:
    name: SGD
    lr:
      name: Constant
      learning_rate: 0.5

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: "VeriWild"
      image_root: "./dataset/market1501"
      cls_label_path: "./dataset/market1501/bounding_box_train.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: [128, 256]
        - RandFlipImage:
            flip_code: 1
        - Pad:
            padding: 10
        - RandCropImage:
            size: [128, 256]
            scale: [0.8022, 0.8022]
            ratio: [0.5, 0.5]
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - RandomErasing:
            EPSILON: 0.5
            sl: 0.02
            sh: 0.4
            r1: 0.3
            mean: [0.4914, 0.4822, 0.4465]
    sampler:
      name: DistributedRandomIdentitySampler
      batch_size: 64
      num_instances: 4
      drop_last: True
      shuffle: True
    loader:
      num_workers: 6
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: "VeriWild"
        image_root: "./dataset/market1501"
        cls_label_path: "./dataset/market1501/query.txt"
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: [128, 256]
          - NormalizeImage:
              scale: 0.00392157
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        name: DistributedBatchSampler
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 6
        use_shared_memory: True

    Gallery:
      dataset:
        name: "VeriWild"
        image_root: "./dataset/market1501"
        cls_label_path: "./dataset/market1501/bounding_box_test.txt"
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: [128, 256]
          - NormalizeImage:
              scale: 0.00392157
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        name: DistributedBatchSampler
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 6
        use_shared_memory: True

Metric:
  Eval:
    - Recallk:
        topk: [1, 5]
    - mAP: {}
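A minimal sketch (not part of the commit; assumes PyYAML is available) of the structural change this config introduces: the Optimizer section is split into "model" and "loss" sub-configs, which is what selects the two-optimizer branch in build_optimizer further below.

import yaml

with open("ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml") as f:
    cfg = yaml.safe_load(f)

opt_cfg = cfg["Optimizer"]
# The flat (single-optimizer) form would carry a top-level 'name' key.
assert "name" not in opt_cfg
print(opt_cfg["model"]["name"])  # Adam -> backbone/neck/head parameters
print(opt_cfg["loss"]["name"])   # SGD  -> CenterLoss center parameters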
ppcls/engine/engine.py
@@ -223,7 +223,12 @@ class Engine(object):
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
                 self.config["Optimizer"], self.config["Global"]["epochs"],
-                len(self.train_dataloader), [self.model])
+                len(self.train_dataloader), [
+                    self.model, *[
+                        m for m in self.train_loss_func.loss_func
+                        if len(m.parameters()) > 0
+                    ]
+                ])

         # for amp training
         if self.amp:
@@ -251,6 +256,11 @@ class Engine(object):
         if self.config["Global"]["distributed"]:
             dist.init_parallel_env()
             self.model = paddle.DataParallel(self.model)
+            # NOTE: parallelize loss which has parameters, such as CenterLoss
+            for i in range(len(self.train_loss_func.loss_func)):
+                if len(self.train_loss_func.loss_func[i].parameters()) > 0:
+                    self.train_loss_func.loss_func[i] = paddle.DataParallel(
+                        self.train_loss_func.loss_func[i])

         # build postprocess for infer
         if self.mode == 'infer':
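A standalone sketch (illustrative stand-ins, not from the commit) of the filtering idea above: only loss modules that own parameters, such as CenterLoss, need to be visible to the optimizer builder and wrapped by DataParallel.

import paddle.nn as nn

# CrossEntropyLoss is stateless; Linear stands in for a parameterized loss.
loss_funcs = [nn.CrossEntropyLoss(), nn.Linear(4, 4)]
with_params = [m for m in loss_funcs if len(m.parameters()) > 0]
print(len(with_params))  # 1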
ppcls/engine/evaluation/retrieval.py
@@ -125,7 +125,14 @@ def cal_feature(engine, name='gallery'):
             out = engine.model(batch[0], batch[1])
             if "Student" in out:
                 out = out["Student"]
-            batch_feas = out["backbone"]
+
+            # get features
+            if engine.config["Global"].get("feat_from", 'backbone') == 'backbone':
+                # use backbone's output as features
+                batch_feas = out["backbone"]
+            else:
+                # use neck's output as features
+                batch_feas = out["neck"]

             # do norm
             if engine.config["Global"].get("feature_normalize", True):
ppcls/engine/train/train.py
@@ -54,16 +54,24 @@ def train_epoch(engine, epoch_id, print_batch_step):
         out = forward(engine, batch)
         loss_dict = engine.train_loss_func(out, batch[1])

-        # step opt and lr
+        # step opt
         if engine.amp:
             scaled = engine.scaler.scale(loss_dict["loss"])
             scaled.backward()
-            engine.scaler.minimize(engine.optimizer, scaled)
+            for i in range(len(engine.optimizer)):
+                engine.scaler.minimize(engine.optimizer[i], scaled)
         else:
             loss_dict["loss"].backward()
-            engine.optimizer.step()
-        engine.optimizer.clear_grad()
-        engine.lr_sch.step()
+            for i in range(len(engine.optimizer)):
+                engine.optimizer[i].step()
+        # clear grad
+        for i in range(len(engine.optimizer)):
+            engine.optimizer[i].clear_grad()
+        # step lr
+        for i in range(len(engine.lr_sch)):
+            engine.lr_sch[i].step()

         # below code just for logging
         # update metric_for_logger
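The same stepping pattern in isolation, as a minimal runnable sketch (stand-in modules, not the engine's real objects):

import paddle

model = paddle.nn.Linear(8, 4)
center_like = paddle.nn.Linear(4, 4)  # stand-in for a loss layer with parameters
optimizers = [
    paddle.optimizer.Adam(learning_rate=3.5e-4, parameters=model.parameters()),
    paddle.optimizer.SGD(learning_rate=0.5, parameters=center_like.parameters()),
]

x = paddle.randn([2, 8])
loss = (center_like(model(x)) ** 2).mean()
loss.backward()
for opt in optimizers:  # step every optimizer, as train_epoch now does
    opt.step()
for opt in optimizers:  # then clear all grads
    opt.clear_grad()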
ppcls/engine/train/utils.py
@@ -38,7 +38,12 @@ def update_loss(trainer, loss_dict, batch_size):
 def log_info(trainer, batch_size, epoch_id, iter_id):
-    lr_msg = "lr: {:.5f}".format(trainer.lr_sch.get_lr())
+    if len(trainer.lr_sch) <= 1:
+        lr_msg = "lr: {:.8f}".format(trainer.lr_sch[0].get_lr())
+    else:
+        lr_msg = "lr_model: {:.8f}".format(trainer.lr_sch[0].get_lr())
+        lr_msg += ", lr_loss: {:.8f}".format(trainer.lr_sch[1].get_lr())
+
     metric_msg = ", ".join([
         "{}: {:.5f}".format(key, trainer.output_info[key].avg)
         for key in trainer.output_info
@@ -58,12 +63,23 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
             epoch_id, trainer.config["Global"]["epochs"], iter_id,
             len(trainer.train_dataloader), lr_msg, metric_msg, time_msg,
             ips_msg, eta_msg))

-    logger.scaler(
-        name="lr",
-        value=trainer.lr_sch.get_lr(),
-        step=trainer.global_step,
-        writer=trainer.vdl_writer)
+    if len(trainer.lr_sch) <= 1:
+        logger.scaler(
+            name="lr",
+            value=trainer.lr_sch[0].get_lr(),
+            step=trainer.global_step,
+            writer=trainer.vdl_writer)
+    else:
+        logger.scaler(
+            name="lr_model",
+            value=trainer.lr_sch[0].get_lr(),
+            step=trainer.global_step,
+            writer=trainer.vdl_writer)
+        logger.scaler(
+            name="lr_loss",
+            value=trainer.lr_sch[1].get_lr(),
+            step=trainer.global_step,
+            writer=trainer.vdl_writer)
     for key in trainer.output_info:
         logger.scaler(
             name="train_{}".format(key),
ppcls/loss/centerloss.py (largely rewritten in this commit: centers become a trainable parameter, the loss gains type annotations, and features are taken from out['backbone'] instead of input["features"]; the resulting file follows)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import Dict

import paddle
import paddle.nn as nn
from paddle import Tensor


class CenterLoss(nn.Layer):
    """Center loss class

    Args:
        num_classes (int): number of classes.
        feat_dim (int): number of feature dimensions.
    """

    def __init__(self, num_classes: int, feat_dim: int):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        random_init_centers = paddle.randn(
            shape=[self.num_classes, self.feat_dim])
        self.centers = self.create_parameter(
            shape=(self.num_classes, self.feat_dim),
            default_initializer=nn.initializer.Assign(random_init_centers))
        self.add_parameter("centers", self.centers)

    def __call__(self, input: Dict[str, Tensor],
                 target: Tensor) -> Dict[str, Tensor]:
        """compute center loss.

        Args:
            input (Dict[str, Tensor]): {'features': (batch_size, feature_dim), ...}.
            target (Tensor): ground truth label with shape (batch_size, ).

        Returns:
            Dict[str, Tensor]: {'CenterLoss': loss}.
        """
        feats = input['backbone']
        labels = target

        # squeeze labels to shape (batch_size, )
        if labels.ndim >= 2 and labels.shape[-1] == 1:
            labels = paddle.squeeze(labels, axis=[-1])

        batch_size = feats.shape[0]

        # calc feat * feat
        dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True)
        dist1 = paddle.expand(dist1, [batch_size, self.num_classes])

        # dist2 of centers
        dist2 = paddle.sum(
            paddle.square(self.centers), axis=1, keepdim=True)  # num_classes
        dist2 = paddle.expand(
            dist2, [self.num_classes, batch_size]).astype("float64")
        dist2 = paddle.transpose(dist2, [1, 0])

        # first x * x + y * y
        distmat = paddle.add(dist1, dist2)
        tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0]))
        distmat = distmat - 2.0 * tmp

        # generate the mask
        classes = paddle.arange(self.num_classes)
        labels = paddle.expand(
            paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
        mask = paddle.equal(
            paddle.expand(classes, [batch_size, self.num_classes]),
            labels).astype("float32")  # get mask
        dist = paddle.multiply(distmat, mask)

        loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
        return {'CenterLoss': loss}
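The distance computation above relies on the expansion ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c; a short sanity sketch (not from the commit) checking that identity with plain Paddle ops:

import paddle

x = paddle.randn([4, 8])  # a batch of features
c = paddle.randn([3, 8])  # three class centers
direct = ((x.unsqueeze(1) - c.unsqueeze(0)) ** 2).sum(-1)
expanded = (x ** 2).sum(1, keepdim=True) \
    + (c ** 2).sum(1).unsqueeze(0) \
    - 2.0 * paddle.matmul(x, c.t())
print(paddle.allclose(direct, expanded))  # True, up to float error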
ppcls/optimizer/__init__.py
@@ -44,29 +44,97 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):

build_optimizer after this commit (the old single-optimizer path moves under the 'name' branch; a new branch builds separate optimizers for the model and for parameterized losses):

# model_list is None in static graph
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
    config = copy.deepcopy(config)
    if 'name' in config:
        # NOTE: build optimizer and lr for model only.
        # step1 build lr
        lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
        logger.debug("build model's lr ({}) success..".format(lr))

        # step2 build regularization
        if 'regularizer' in config and config['regularizer'] is not None:
            if 'weight_decay' in config:
                logger.warning(
                    "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
                )
            reg_config = config.pop('regularizer')
            reg_name = reg_config.pop('name') + 'Decay'
            reg = getattr(paddle.regularizer, reg_name)(**reg_config)
            config["weight_decay"] = reg
            logger.debug("build model's regularizer ({}) success..".format(reg))

        # step3 build optimizer
        optim_name = config.pop('name')
        if 'clip_norm' in config:
            clip_norm = config.pop('clip_norm')
            grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
        else:
            grad_clip = None
        optim = getattr(optimizer, optim_name)(
            learning_rate=lr, grad_clip=grad_clip,
            **config)(model_list=model_list[0:1])
        optim = [optim, ]
        lr = [lr, ]
        logger.debug("build model's optimizer ({}) success..".format(optim))
    else:
        # NOTE: build optimizer and lr for model and loss.
        config_model = config['model']
        config_loss = config['loss']

        # step1 build lr
        lr_model = build_lr_scheduler(
            config_model.pop('lr'), epochs, step_each_epoch)
        logger.debug("build model's lr ({}) success..".format(lr_model))

        # step2 build regularization
        if 'regularizer' in config_model and config_model[
                'regularizer'] is not None:
            if 'weight_decay' in config_model:
                logger.warning(
                    "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
                )
            reg_config = config_model.pop('regularizer')
            reg_name = reg_config.pop('name') + 'Decay'
            reg_model = getattr(paddle.regularizer, reg_name)(**reg_config)
            config_model["weight_decay"] = reg_model
            logger.debug("build model's regularizer ({}) success..".format(
                reg_model))

        # step3 build optimizer
        optim_name = config_model.pop('name')
        if 'clip_norm' in config_model:
            clip_norm = config_model.pop('clip_norm')
            grad_clip_model = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
        else:
            grad_clip_model = None
        optim_model = getattr(optimizer, optim_name)(
            learning_rate=lr_model,
            grad_clip=grad_clip_model,
            **config_model)(model_list=model_list[0:1])

        # step4 build lr for loss
        lr_loss = build_lr_scheduler(
            config_loss.pop('lr'), epochs, step_each_epoch)
        logger.debug("build loss's lr ({}) success..".format(lr_loss))

        # step5 build regularization for loss
        if 'regularizer' in config_loss and config_loss[
                'regularizer'] is not None:
            if 'weight_decay' in config_loss:
                logger.warning(
                    "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
                )
            reg_config = config_loss.pop('regularizer')
            reg_name = reg_config.pop('name') + 'Decay'
            reg_loss = getattr(paddle.regularizer, reg_name)(**reg_config)
            config_loss["weight_decay"] = reg_loss
            logger.debug("build loss's regularizer ({}) success..".format(
                reg_loss))

        # step6 build optimizer for loss
        optim_name = config_loss.pop('name')
        if 'clip_norm' in config_loss:
            clip_norm = config_loss.pop('clip_norm')
            grad_clip_loss = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
        else:
            grad_clip_loss = None
        optim_loss = getattr(optimizer, optim_name)(
            learning_rate=lr_loss,
            grad_clip=grad_clip_loss,
            **config_loss)(model_list=model_list[1:2])

        optim = [optim_model, optim_loss]
        lr = [lr_model, lr_loss]
        logger.debug("build loss's optimizer ({}) success..".format(optim))
    return optim, lr
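A minimal sketch (assumed shapes, mirroring the configs in this commit) of the two Optimizer config forms the rewritten build_optimizer distinguishes:

flat_cfg = {  # old form: one optimizer for everything
    "name": "Momentum",
    "momentum": 0.9,
    "lr": {"name": "Cosine", "learning_rate": 0.1},
}
split_cfg = {  # new form: separate model and loss optimizers
    "model": {"name": "Adam", "lr": {"name": "Piecewise",
                                     "decay_epochs": [30, 60],
                                     "values": [0.00035, 0.000035, 0.0000035]}},
    "loss": {"name": "SGD", "lr": {"name": "Constant", "learning_rate": 0.5}},
}
for cfg in (flat_cfg, split_cfg):
    # a top-level 'name' key selects the old single-optimizer branch
    print("single optimizer" if "name" in cfg else "model + loss optimizers")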
ppcls/optimizer/learning_rate.py
@@ -75,6 +75,23 @@ class Linear(object):
         return learning_rate


+class Constant(LRScheduler):
+    """
+    Constant learning rate
+    Args:
+        lr (float): The initial learning rate. It is a python float number.
+        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+    """
+
+    def __init__(self, learning_rate, last_epoch=-1, **kwargs):
+        self.learning_rate = learning_rate
+        self.last_epoch = last_epoch
+        super().__init__()
+
+    def get_lr(self):
+        return self.learning_rate
+
+
 class Cosine(object):
     """
     Cosine learning rate decay
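Usage sketch for the new scheduler (a minimal example, assuming the class above is importable from its module):

from ppcls.optimizer.learning_rate import Constant  # the file shown above

sched = Constant(learning_rate=0.5)
sched.step()
print(sched.get_lr())  # always 0.5, matching the CenterLoss SGD config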
ppcls/utils/save_load.py
@@ -48,7 +48,7 @@ def _mkdir_if_not_exist(path):
 def load_dygraph_pretrain(model, path=None):
     if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
         raise ValueError("Model pretrain path {} does not "
-                         "exists.".format(path))
+                         "exists.".format(path + '.pdparams'))
     param_state_dict = paddle.load(path + ".pdparams")
     model.set_dict(param_state_dict)
     return
@@ -99,7 +99,8 @@ def init_model(config, net, optimizer=None):
         opti_dict = paddle.load(checkpoints + ".pdopt")
         metric_dict = paddle.load(checkpoints + ".pdstates")
         net.set_dict(para_dict)
-        optimizer.set_state_dict(opti_dict)
+        for i in range(len(optimizer)):
+            optimizer[i].set_state_dict(opti_dict)
         logger.info("Finish load checkpoints from {}".format(checkpoints))
         return metric_dict
@@ -131,6 +132,6 @@ def save_model(net,
     model_path = os.path.join(model_path, prefix)
     paddle.save(net.state_dict(), model_path + ".pdparams")
-    paddle.save(optimizer.state_dict(), model_path + ".pdopt")
+    paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
     paddle.save(metric_info, model_path + ".pdstates")
     logger.info("Already save model in {}".format(model_path))