Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
0d7e595f
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0d7e595f
编写于
2月 27, 2023
作者:
G
gaotingquan
提交者:
Wei Shengyu
3月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
mv model_saver to __init__()
上级
6e77bd6c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
89 addition
and
94 deletion
+89
-94
ppcls/engine/engine.py
ppcls/engine/engine.py
+13
-14
ppcls/utils/model_saver.py
ppcls/utils/model_saver.py
+0
-80
ppcls/utils/save_load.py
ppcls/utils/save_load.py
+76
-0
未找到文件。
ppcls/engine/engine.py
浏览文件 @
0d7e595f
...
...
@@ -33,8 +33,7 @@ from ppcls.metric import build_metrics
from
ppcls.optimizer
import
build_optimizer
from
ppcls.utils.ema
import
ExponentialMovingAverage
from
ppcls.utils.save_load
import
load_dygraph_pretrain
,
load_dygraph_pretrain_from_url
from
ppcls.utils.save_load
import
init_model
from
ppcls.utils.model_saver
import
ModelSaver
from
ppcls.utils.save_load
import
init_model
,
ModelSaver
from
ppcls.data.utils.get_image_list
import
get_image_list
from
ppcls.data.postprocess
import
build_postprocess
...
...
@@ -100,6 +99,14 @@ class Engine(object):
# for distributed
self
.
_init_dist
()
# build model saver
self
.
model_saver
=
ModelSaver
(
self
,
net_name
=
"model"
,
loss_name
=
"train_loss_func"
,
opt_name
=
"optimizer"
,
model_ema_name
=
"model_ema"
)
print_config
(
config
)
def
train
(
self
):
...
...
@@ -129,14 +136,6 @@ class Engine(object):
# TODO: mv best_metric_ema to best_metric dict
best_metric_ema
=
0
# build model saver
model_saver
=
ModelSaver
(
self
,
net_name
=
"model"
,
loss_name
=
"train_loss_func"
,
opt_name
=
"optimizer"
,
model_ema_name
=
"model_ema"
)
self
.
_init_checkpoints
(
best_metric
)
# global iter counter
...
...
@@ -166,7 +165,7 @@ class Engine(object):
if
acc
>
best_metric
[
"metric"
]:
best_metric
[
"metric"
]
=
acc
best_metric
[
"epoch"
]
=
epoch_id
model_saver
.
save
(
self
.
model_saver
.
save
(
best_metric
,
prefix
=
"best_model"
,
save_student_model
=
True
)
...
...
@@ -189,7 +188,7 @@ class Engine(object):
if
acc_ema
>
best_metric_ema
:
best_metric_ema
=
acc_ema
model_saver
.
save
(
self
.
model_saver
.
save
(
{
"metric"
:
acc_ema
,
"epoch"
:
epoch_id
...
...
@@ -205,7 +204,7 @@ class Engine(object):
# save model
if
save_interval
>
0
and
epoch_id
%
save_interval
==
0
:
model_saver
.
save
(
self
.
model_saver
.
save
(
{
"metric"
:
acc
,
"epoch"
:
epoch_id
...
...
@@ -213,7 +212,7 @@ class Engine(object):
prefix
=
f
"epoch_
{
epoch_id
}
"
)
# save the latest model
model_saver
.
save
(
self
.
model_saver
.
save
(
{
"metric"
:
acc
,
"epoch"
:
epoch_id
...
...
ppcls/utils/model_saver.py
已删除
100644 → 0
浏览文件 @
6e77bd6c
import
os
import
paddle
from
.
import
logger
def
_mkdir_if_not_exist
(
path
):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if
not
os
.
path
.
exists
(
path
):
try
:
os
.
makedirs
(
path
)
except
OSError
as
e
:
if
e
.
errno
==
errno
.
EEXIST
and
os
.
path
.
isdir
(
path
):
logger
.
warning
(
'be happy if some process has already created {}'
.
format
(
path
))
else
:
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
def
_extract_student_weights
(
all_params
,
student_prefix
=
"Student."
):
s_params
=
{
key
[
len
(
student_prefix
):]:
all_params
[
key
]
for
key
in
all_params
if
student_prefix
in
key
}
return
s_params
class
ModelSaver
(
object
):
def
__init__
(
self
,
engine
,
net_name
=
"model"
,
loss_name
=
"train_loss_func"
,
opt_name
=
"optimizer"
,
model_ema_name
=
"model_ema"
):
# net, loss, opt, model_ema, output_dir,
self
.
engine
=
engine
self
.
net_name
=
net_name
self
.
loss_name
=
loss_name
self
.
opt_name
=
opt_name
self
.
model_ema_name
=
model_ema_name
arch_name
=
engine
.
config
[
"Arch"
][
"name"
]
self
.
output_dir
=
os
.
path
.
join
(
engine
.
output_dir
,
arch_name
)
_mkdir_if_not_exist
(
self
.
output_dir
)
def
save
(
self
,
metric_info
,
prefix
=
'ppcls'
,
save_student_model
=
False
):
if
paddle
.
distributed
.
get_rank
()
!=
0
:
return
save_dir
=
os
.
path
.
join
(
self
.
output_dir
,
prefix
)
params_state_dict
=
getattr
(
self
.
engine
,
self
.
net_name
).
state_dict
()
loss
=
getattr
(
self
.
engine
,
self
.
loss_name
)
if
loss
is
not
None
:
loss_state_dict
=
loss
.
state_dict
()
keys_inter
=
set
(
params_state_dict
.
keys
())
&
set
(
loss_state_dict
.
keys
())
assert
len
(
keys_inter
)
==
0
,
\
f
"keys in model and loss state_dict must be unique, but got intersection
{
keys_inter
}
"
params_state_dict
.
update
(
loss_state_dict
)
if
save_student_model
:
s_params
=
_extract_student_weights
(
params_state_dict
)
if
len
(
s_params
)
>
0
:
paddle
.
save
(
s_params
,
save_dir
+
"_student.pdparams"
)
paddle
.
save
(
params_state_dict
,
save_dir
+
".pdparams"
)
model_ema
=
getattr
(
self
.
engine
,
self
.
model_ema_name
)
if
model_ema
is
not
None
:
paddle
.
save
(
model_ema
.
module
.
state_dict
(),
save_dir
+
".ema.pdparams"
)
optimizer
=
getattr
(
self
.
engine
,
self
.
opt_name
)
paddle
.
save
([
opt
.
state_dict
()
for
opt
in
optimizer
],
save_dir
+
".pdopt"
)
paddle
.
save
(
metric_info
,
save_dir
+
".pdstates"
)
logger
.
info
(
"Already save model in {}"
.
format
(
save_dir
))
ppcls/utils/save_load.py
浏览文件 @
0d7e595f
...
...
@@ -123,3 +123,79 @@ def init_model(config,
load_dygraph_pretrain
(
net
,
path
=
pretrained_model
)
logger
.
info
(
"Finish load pretrained model from {}"
.
format
(
pretrained_model
))
def
_mkdir_if_not_exist
(
path
):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if
not
os
.
path
.
exists
(
path
):
try
:
os
.
makedirs
(
path
)
except
OSError
as
e
:
if
e
.
errno
==
errno
.
EEXIST
and
os
.
path
.
isdir
(
path
):
logger
.
warning
(
'be happy if some process has already created {}'
.
format
(
path
))
else
:
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
def
_extract_student_weights
(
all_params
,
student_prefix
=
"Student."
):
s_params
=
{
key
[
len
(
student_prefix
):]:
all_params
[
key
]
for
key
in
all_params
if
student_prefix
in
key
}
return
s_params
class
ModelSaver
(
object
):
def
__init__
(
self
,
engine
,
net_name
=
"model"
,
loss_name
=
"train_loss_func"
,
opt_name
=
"optimizer"
,
model_ema_name
=
"model_ema"
):
# net, loss, opt, model_ema, output_dir,
self
.
engine
=
engine
self
.
net_name
=
net_name
self
.
loss_name
=
loss_name
self
.
opt_name
=
opt_name
self
.
model_ema_name
=
model_ema_name
arch_name
=
engine
.
config
[
"Arch"
][
"name"
]
self
.
output_dir
=
os
.
path
.
join
(
engine
.
output_dir
,
arch_name
)
_mkdir_if_not_exist
(
self
.
output_dir
)
def
save
(
self
,
metric_info
,
prefix
=
'ppcls'
,
save_student_model
=
False
):
if
paddle
.
distributed
.
get_rank
()
!=
0
:
return
save_dir
=
os
.
path
.
join
(
self
.
output_dir
,
prefix
)
params_state_dict
=
getattr
(
self
.
engine
,
self
.
net_name
).
state_dict
()
loss
=
getattr
(
self
.
engine
,
self
.
loss_name
)
if
loss
is
not
None
:
loss_state_dict
=
loss
.
state_dict
()
keys_inter
=
set
(
params_state_dict
.
keys
())
&
set
(
loss_state_dict
.
keys
())
assert
len
(
keys_inter
)
==
0
,
\
f
"keys in model and loss state_dict must be unique, but got intersection
{
keys_inter
}
"
params_state_dict
.
update
(
loss_state_dict
)
if
save_student_model
:
s_params
=
_extract_student_weights
(
params_state_dict
)
if
len
(
s_params
)
>
0
:
paddle
.
save
(
s_params
,
save_dir
+
"_student.pdparams"
)
paddle
.
save
(
params_state_dict
,
save_dir
+
".pdparams"
)
model_ema
=
getattr
(
self
.
engine
,
self
.
model_ema_name
)
if
model_ema
is
not
None
:
paddle
.
save
(
model_ema
.
module
.
state_dict
(),
save_dir
+
".ema.pdparams"
)
optimizer
=
getattr
(
self
.
engine
,
self
.
opt_name
)
paddle
.
save
([
opt
.
state_dict
()
for
opt
in
optimizer
],
save_dir
+
".pdopt"
)
paddle
.
save
(
metric_info
,
save_dir
+
".pdstates"
)
logger
.
info
(
"Already save model in {}"
.
format
(
save_dir
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录