PaddlePaddle / PaddleClas · Commit 43a03a0c
Unverified commit 43a03a0c, authored Apr 20, 2022 by Walter, committed via GitHub on Apr 20, 2022

Merge pull request #1853 from HydrogenSulfate/multi_optim

support for multi optimizer case

Parents: 42cb6435, 15242df1
Showing 6 changed files with 149 additions and 60 deletions (+149 / -60)
ppcls/engine/engine.py       +25  -11
ppcls/engine/train/train.py  +11  -5
ppcls/engine/train/utils.py  +10  -6
ppcls/loss/__init__.py       +1   -0
ppcls/optimizer/__init__.py  +76  -26
ppcls/utils/save_load.py     +26  -12
ppcls/engine/engine.py

@@ -214,16 +214,19 @@ class Engine(object):
         if self.config["Global"]["pretrained_model"] is not None:
             if self.config["Global"]["pretrained_model"].startswith("http"):
                 load_dygraph_pretrain_from_url(
-                    self.model, self.config["Global"]["pretrained_model"])
+                    [self.model, getattr(self, 'train_loss_func', None)],
+                    self.config["Global"]["pretrained_model"])
             else:
                 load_dygraph_pretrain(
-                    self.model, self.config["Global"]["pretrained_model"])
+                    [self.model, getattr(self, 'train_loss_func', None)],
+                    self.config["Global"]["pretrained_model"])

         # build optimizer
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
-                self.config["Optimizer"], self.config["Global"]["epochs"],
-                len(self.train_dataloader), [self.model])
+                self.config, self.config["Global"]["epochs"],
+                len(self.train_dataloader),
+                [self.model, self.train_loss_func])

         # for amp training
         if self.amp:
@@ -241,6 +244,11 @@ class Engine(object):
                 optimizers=self.optimizer,
                 level=amp_level,
                 save_dtype='float32')
+            if len(self.train_loss_func.parameters()) > 0:
+                self.train_loss_func = paddle.amp.decorate(
+                    models=self.train_loss_func,
+                    level=amp_level,
+                    save_dtype='float32')

         # for distributed
         world_size = dist.get_world_size()
@@ -251,7 +259,10 @@ class Engine(object):
         if self.config["Global"]["distributed"]:
             dist.init_parallel_env()
             self.model = paddle.DataParallel(self.model)
+            if self.mode == 'train' and len(self.train_loss_func.parameters(
+            )) > 0:
+                self.train_loss_func = paddle.DataParallel(
+                    self.train_loss_func)

         # build postprocess for infer
         if self.mode == 'infer':
             self.preprocess_func = create_operators(self.config["Infer"][
@@ -279,9 +290,9 @@ class Engine(object):
         # global iter counter
         self.global_step = 0

-        if self.config["Global"]["checkpoints"] is not None:
-            metric_info = init_model(self.config["Global"], self.model,
-                                     self.optimizer)
+        if self.config.Global.checkpoints is not None:
+            metric_info = init_model(self.config.Global, self.model,
+                                     self.optimizer, self.train_loss_func)
             if metric_info is not None:
                 best_metric.update(metric_info)
@@ -317,7 +328,8 @@ class Engine(object):
                     best_metric,
                     self.output_dir,
                     model_name=self.config["Arch"]["name"],
-                    prefix="best_model")
+                    prefix="best_model",
+                    loss=self.train_loss_func)
                 logger.info("[Eval][Epoch {}][best metric: {}]".format(
                     epoch_id, best_metric["metric"]))
                 logger.scaler(
@@ -336,7 +348,8 @@ class Engine(object):
                     "epoch": epoch_id
                 },
                 self.output_dir,
                 model_name=self.config["Arch"]["name"],
-                prefix="epoch_{}".format(epoch_id))
+                prefix="epoch_{}".format(epoch_id),
+                loss=self.train_loss_func)
         # save the latest model
         save_load.save_model(
             self.model,
@@ -344,7 +357,8 @@ class Engine(object):
                 "epoch": epoch_id
             },
             self.output_dir,
             model_name=self.config["Arch"]["name"],
-            prefix="latest")
+            prefix="latest",
+            loss=self.train_loss_func)
         if self.vdl_writer is not None:
             self.vdl_writer.close()
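Note that the engine only wraps the loss with paddle.amp.decorate or paddle.DataParallel when the loss actually owns trainable parameters. A minimal sketch of that check, using a hypothetical ToyCenterLoss stand-in rather than any PaddleClas class:

# Minimal sketch of the len(train_loss_func.parameters()) > 0 guard above.
# ToyCenterLoss is a made-up stand-in for a loss with learnable state.
import paddle
import paddle.nn as nn

class ToyCenterLoss(nn.Layer):
    def __init__(self, num_classes=10, feat_dim=8):
        super().__init__()
        # learnable class centers, analogous in spirit to CenterLoss
        self.centers = self.create_parameter(shape=[num_classes, feat_dim])

    def forward(self, feats, labels):
        centers = paddle.index_select(self.centers, labels, axis=0)
        return ((feats - centers) ** 2).sum(axis=1).mean()

plain_loss = nn.CrossEntropyLoss()   # no trainable parameters
param_loss = ToyCenterLoss()         # has trainable parameters

print(len(plain_loss.parameters()))  # 0  -> skip amp.decorate / DataParallel
print(len(param_loss.parameters()))  # 1  -> wrapped and trained like the model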
ppcls/engine/train/train.py

@@ -53,16 +53,22 @@ def train_epoch(engine, epoch_id, print_batch_step):
         out = forward(engine, batch)
         loss_dict = engine.train_loss_func(out, batch[1])

-        # step opt and lr
+        # step opt
         if engine.amp:
             scaled = engine.scaler.scale(loss_dict["loss"])
             scaled.backward()
-            engine.scaler.minimize(engine.optimizer, scaled)
+            for i in range(len(engine.optimizer)):
+                engine.scaler.minimize(engine.optimizer[i], scaled)
         else:
             loss_dict["loss"].backward()
-            engine.optimizer.step()
-        engine.optimizer.clear_grad()
-        engine.lr_sch.step()
+            for i in range(len(engine.optimizer)):
+                engine.optimizer[i].step()

+        # clear grad
+        for i in range(len(engine.optimizer)):
+            engine.optimizer[i].clear_grad()

+        # step lr
+        for i in range(len(engine.lr_sch)):
+            engine.lr_sch[i].step()

         # below code just for logging
         # update metric_for_logger
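Because engine.optimizer and engine.lr_sch are now lists, a single backward pass is followed by one step/clear_grad per optimizer and one step per LR scheduler. A standalone sketch of that loop with two toy Paddle optimizers (layers, optimizer choice and hyper-parameters are illustrative, not taken from any PaddleClas config):

# One backward pass, then iterate over all optimizers and schedulers,
# mirroring the list-based update loop in train_epoch.
import paddle
import paddle.nn as nn

backbone = nn.Linear(4, 4)
head = nn.Linear(4, 2)

lr_sch = [
    paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.1, T_max=10),
    paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.01, T_max=10),
]
optimizers = [
    paddle.optimizer.Momentum(learning_rate=lr_sch[0], momentum=0.9,
                              parameters=backbone.parameters()),
    paddle.optimizer.Momentum(learning_rate=lr_sch[1], momentum=0.9,
                              parameters=head.parameters()),
]

x = paddle.randn([8, 4])
loss = head(backbone(x)).mean()
loss.backward()            # a single backward populates all gradients

for opt in optimizers:     # step every optimizer
    opt.step()
for opt in optimizers:     # then clear the gradients
    opt.clear_grad()
for sch in lr_sch:         # and advance every LR schedule
    sch.step()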
ppcls/engine/train/utils.py

@@ -38,7 +38,10 @@ def update_loss(trainer, loss_dict, batch_size):

 def log_info(trainer, batch_size, epoch_id, iter_id):
-    lr_msg = "lr: {:.5f}".format(trainer.lr_sch.get_lr())
+    lr_msg = ", ".join([
+        "lr_{}: {:.8f}".format(i + 1, lr.get_lr())
+        for i, lr in enumerate(trainer.lr_sch)
+    ])
     metric_msg = ", ".join([
         "{}: {:.5f}".format(key, trainer.output_info[key].avg)
         for key in trainer.output_info
@@ -59,11 +62,12 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
         len(trainer.train_dataloader), lr_msg, metric_msg, time_msg, ips_msg,
         eta_msg))

-    logger.scaler(
-        name="lr",
-        value=trainer.lr_sch.get_lr(),
-        step=trainer.global_step,
-        writer=trainer.vdl_writer)
+    for i, lr in enumerate(trainer.lr_sch):
+        logger.scaler(
+            name="lr_{}".format(i + 1),
+            value=lr.get_lr(),
+            step=trainer.global_step,
+            writer=trainer.vdl_writer)
     for key in trainer.output_info:
         logger.scaler(
             name="train_{}".format(key),
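For reference, the new lr_msg contains one entry per scheduler instead of a single "lr" value. A tiny sketch with two toy schedulers (values are illustrative):

import paddle

lr_sch = [
    paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=30),
    paddle.optimizer.lr.StepDecay(learning_rate=0.01, step_size=30),
]
lr_msg = ", ".join(
    "lr_{}: {:.8f}".format(i + 1, lr.get_lr()) for i, lr in enumerate(lr_sch))
print(lr_msg)  # lr_1: 0.10000000, lr_2: 0.01000000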
ppcls/loss/__init__.py

@@ -47,6 +47,7 @@ class CombinedLoss(nn.Layer):
                 param.keys())
             self.loss_weight.append(param.pop("weight"))
             self.loss_func.append(eval(name)(**param))
+        self.loss_func = nn.LayerList(self.loss_func)

     def __call__(self, input, batch):
         loss_dict = {}
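Wrapping the sub-losses in nn.LayerList (rather than keeping a plain Python list) is what makes their parameters visible through CombinedLoss.parameters(), so a parameterized loss can be optimized, saved and restored like the model. A small sketch of the difference, using toy layers rather than real loss classes:

# A plain Python list hides sub-layer parameters from .parameters();
# nn.LayerList registers them as sub-layers.
import paddle.nn as nn

class PlainHolder(nn.Layer):
    def __init__(self):
        super().__init__()
        self.loss_func = [nn.Linear(4, 4)]                 # not registered

class ListHolder(nn.Layer):
    def __init__(self):
        super().__init__()
        self.loss_func = nn.LayerList([nn.Linear(4, 4)])   # registered

print(len(PlainHolder().parameters()))  # 0
print(len(ListHolder().parameters()))   # 2 (weight + bias)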
ppcls/optimizer/__init__.py

@@ -18,6 +18,7 @@ from __future__ import print_function

 import copy
 import paddle
+from typing import Dict, List

 from ppcls.utils import logger
@@ -44,29 +45,78 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):

 # model_list is None in static graph
 def build_optimizer(config, epochs, step_each_epoch, model_list=None):
-    config = copy.deepcopy(config)
-    # step1 build lr
-    lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
-    logger.debug("build lr ({}) success..".format(lr))
-    # step2 build regularization
-    if 'regularizer' in config and config['regularizer'] is not None:
-        if 'weight_decay' in config:
-            logger.warning(
-                "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
-            )
-        reg_config = config.pop('regularizer')
-        reg_name = reg_config.pop('name') + 'Decay'
-        reg = getattr(paddle.regularizer, reg_name)(**reg_config)
-        config["weight_decay"] = reg
-        logger.debug("build regularizer ({}) success..".format(reg))
-    # step3 build optimizer
-    optim_name = config.pop('name')
-    if 'clip_norm' in config:
-        clip_norm = config.pop('clip_norm')
-        grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
-    else:
-        grad_clip = None
-    optim = getattr(optimizer, optim_name)(learning_rate=lr,
-                                           grad_clip=grad_clip,
-                                           **config)(model_list=model_list)
-    logger.debug("build optimizer ({}) success..".format(optim))
-    return optim, lr
+    optim_config = config["Optimizer"]
+    if isinstance(optim_config, dict):
+        # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
+        optim_name = optim_config.pop("name")
+        optim_config: List[Dict[str, Dict]] = [{
+            optim_name: {
+                'scope': "all",
+                **optim_config
+            }
+        }]
+    optim_list = []
+    lr_list = []
+    """NOTE:
+    Currently only support optimizer objects below.
+    1. single optimizer config.
+    2. next level under Arch, such as Arch.backbone, Arch.neck, Arch.head.
+    3. loss which has parameters, such as CenterLoss.
+    """
+    for optim_item in optim_config:
+        # optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
+        # step1 build lr
+        optim_name = list(optim_item.keys())[0]  # get optim_name
+        optim_scope = optim_item[optim_name].pop('scope')  # get optim_scope
+        optim_cfg = optim_item[optim_name]  # get optim_cfg
+
+        lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
+        logger.debug("build lr ({}) for scope ({}) success..".format(
+            lr, optim_scope))
+        # step2 build regularization
+        if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
+            if 'weight_decay' in optim_cfg:
+                logger.warning(
+                    "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
+                )
+            reg_config = optim_cfg.pop('regularizer')
+            reg_name = reg_config.pop('name') + 'Decay'
+            reg = getattr(paddle.regularizer, reg_name)(**reg_config)
+            optim_cfg["weight_decay"] = reg
+            logger.debug("build regularizer ({}) for scope ({}) success..".
+                         format(reg, optim_scope))
+        # step3 build optimizer
+        if 'clip_norm' in optim_cfg:
+            clip_norm = optim_cfg.pop('clip_norm')
+            grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
+        else:
+            grad_clip = None
+        optim_model = []
+        for i in range(len(model_list)):
+            if len(model_list[i].parameters()) == 0:
+                continue
+            if optim_scope == "all":
+                # optimizer for all
+                optim_model.append(model_list[i])
+            else:
+                if optim_scope.endswith("Loss"):
+                    # optimizer for loss
+                    for m in model_list[i].sublayers(True):
+                        if m.__class__.__name__ == optim_scope:
+                            optim_model.append(m)
+                else:
+                    # optimizer for a module in the model, such as backbone, neck, head...
+                    if hasattr(model_list[i], optim_scope):
+                        optim_model.append(getattr(model_list[i], optim_scope))
+
+        assert len(optim_model) == 1, \
+            "Invalid optim model for optim scope({}), number of optim_model={}".format(
+                optim_scope, len(optim_model))
+        optim = getattr(optimizer, optim_name)(learning_rate=lr,
+                                               grad_clip=grad_clip,
+                                               **optim_cfg)(
+                                                   model_list=optim_model)
+        logger.debug("build optimizer ({}) for scope ({}) success..".format(
+            optim, optim_scope))
+        optim_list.append(optim)
+        lr_list.append(lr)
+    return optim_list, lr_list
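build_optimizer now normalizes config["Optimizer"] into a list of {OptimizerName: {scope: ..., ...}} entries. A sketch of the two accepted shapes, written as Python literals that mirror the YAML config; the hyper-parameter values and the SGD/CenterLoss pairing are illustrative, not copied from any shipped config:

# 1) Legacy single-optimizer form: a dict, normalized internally to
#    [{'Momentum': {'scope': 'all', ...}}].
single = {
    "name": "Momentum",
    "momentum": 0.9,
    "lr": {"name": "Cosine", "learning_rate": 0.04},
    "regularizer": {"name": "L2", "coeff": 0.0001},
}

# 2) List form: one entry per optimizer, keyed by the optimizer class name
#    and carrying a 'scope' ("all", a sub-module such as "backbone", or a
#    parameterized loss class name such as "CenterLoss").
multi = [
    {"Momentum": {
        "scope": "all",
        "momentum": 0.9,
        "lr": {"name": "Cosine", "learning_rate": 0.04},
    }},
    {"SGD": {
        "scope": "CenterLoss",
        "lr": {"name": "Cosine", "learning_rate": 0.5},
    }},
]

With the list form, the first entry is resolved to the whole model (scope "all"), while the second is attached to the sublayer whose class name is CenterLoss, following the scope-resolution loop above.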
ppcls/utils/save_load.py

@@ -18,9 +18,6 @@ from __future__ import print_function

 import errno
 import os
-import re
-import shutil
-import tempfile

 import paddle

 from ppcls.utils import logger
@@ -47,10 +44,15 @@ def _mkdir_if_not_exist(path):

 def load_dygraph_pretrain(model, path=None):
     if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
-        raise ValueError("Model pretrain path {} does not "
+        raise ValueError("Model pretrain path {}.pdparams does not "
                          "exists.".format(path))
     param_state_dict = paddle.load(path + ".pdparams")
-    model.set_dict(param_state_dict)
+    if isinstance(model, list):
+        for m in model:
+            if hasattr(m, 'set_dict'):
+                m.set_dict(param_state_dict)
+    else:
+        model.set_dict(param_state_dict)
     return
@@ -85,7 +87,7 @@ def load_distillation_model(model, pretrained_model):
         pretrained_model))


-def init_model(config, net, optimizer=None):
+def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
     """
     load model from checkpoint or pretrained_model
     """
@@ -95,11 +97,15 @@ def init_model(config, net, optimizer=None):
             "Given dir {}.pdparams not exist.".format(checkpoints)
         assert os.path.exists(checkpoints + ".pdopt"), \
             "Given dir {}.pdopt not exist.".format(checkpoints)
-        para_dict = paddle.load(checkpoints + ".pdparams")
+        # load state dict
         opti_dict = paddle.load(checkpoints + ".pdopt")
+        para_dict = paddle.load(checkpoints + ".pdparams")
         metric_dict = paddle.load(checkpoints + ".pdstates")
-        net.set_dict(para_dict)
-        optimizer.set_state_dict(opti_dict)
+        # set state dict
+        net.set_state_dict(para_dict)
+        loss.set_state_dict(para_dict)
+        for i in range(len(optimizer)):
+            optimizer[i].set_state_dict(opti_dict)
         logger.info("Finish load checkpoints from {}".format(checkpoints))
         return metric_dict
@@ -120,7 +126,8 @@ def save_model(net,
                metric_info,
                model_path,
                model_name="",
-               prefix='ppcls'):
+               prefix='ppcls',
+               loss: paddle.nn.Layer=None):
     """
     save model to the target path
     """
@@ -130,7 +137,14 @@ def save_model(net,
     _mkdir_if_not_exist(model_path)
     model_path = os.path.join(model_path, prefix)

-    paddle.save(net.state_dict(), model_path + ".pdparams")
-    paddle.save(optimizer.state_dict(), model_path + ".pdopt")
+    params_state_dict = net.state_dict()
+    loss_state_dict = loss.state_dict()
+    keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys())
+    assert len(keys_inter) == 0, \
+        f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
+    params_state_dict.update(loss_state_dict)
+
+    paddle.save(params_state_dict, model_path + ".pdparams")
+    paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
     paddle.save(metric_info, model_path + ".pdstates")
     logger.info("Already save model in {}".format(model_path))
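The checkpoint layout changes accordingly: a single merged .pdparams holds both model and loss parameters (guarded by a key-collision assert), and .pdopt stores a list of optimizer state dicts. A minimal round-trip sketch with toy layers and hypothetical file names, not the engine's actual save paths:

import paddle
import paddle.nn as nn

class ToyModel(nn.Layer):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 2)   # keys: backbone.weight, backbone.bias

class ToyParamLoss(nn.Layer):
    def __init__(self):
        super().__init__()
        self.centers = self.create_parameter(shape=[2, 2])  # key: centers

model, loss_layer = ToyModel(), ToyParamLoss()
optimizers = [
    paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters()),
    paddle.optimizer.SGD(learning_rate=0.5, parameters=loss_layer.parameters()),
]

# save: merge model and loss parameters, asserting their keys do not collide
params = model.state_dict()
loss_params = loss_layer.state_dict()
assert not (set(params) & set(loss_params))
params.update(loss_params)
paddle.save(params, "latest.pdparams")
paddle.save([opt.state_dict() for opt in optimizers], "latest.pdopt")

# load: each layer picks its own keys out of the merged dict, and every
# optimizer restores its entry from the saved list
merged = paddle.load("latest.pdparams")
model.set_state_dict(merged)
loss_layer.set_state_dict(merged)
for opt, state in zip(optimizers, paddle.load("latest.pdopt")):
    opt.set_state_dict(state)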