Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
5a4ee1ae
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5a4ee1ae
编写于
3月 14, 2023
作者:
T
Tingquan Gao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert "refactor"
This reverts commit
0e28a39d
.
上级
502aead9
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
12 addition
and
88 deletion
+12
-88
ppcls/engine/engine.py
ppcls/engine/engine.py
+0
-2
ppcls/engine/train/regular_train_epoch.py
ppcls/engine/train/regular_train_epoch.py
+12
-6
ppcls/utils/model_saver.py
ppcls/utils/model_saver.py
+0
-80
未找到文件。
ppcls/engine/engine.py
浏览文件 @
5a4ee1ae
...
...
@@ -447,8 +447,6 @@ class Engine(object):
level
=
self
.
amp_level
,
save_dtype
=
'float32'
)
self
.
amp_level
=
engine
.
config
[
"AMP"
].
get
(
"level"
,
"O1"
).
upper
()
def
_init_dist
(
self
):
# check the gpu num
world_size
=
dist
.
get_world_size
()
...
...
ppcls/engine/train/regular_train_epoch.py
浏览文件 @
5a4ee1ae
...
...
@@ -36,25 +36,31 @@ def regular_train_epoch(engine, epoch_id, print_batch_step):
batch
[
1
]
=
batch
[
1
].
reshape
([
batch_size
,
-
1
])
engine
.
global_step
+=
1
#
forward & backward & step op
t
#
image inpu
t
if
engine
.
amp
:
amp_level
=
engine
.
config
[
"AMP"
].
get
(
"level"
,
"O1"
).
upper
()
with
paddle
.
amp
.
auto_cast
(
custom_black_list
=
{
"flatten_contiguous_range"
,
"greater_than"
},
level
=
engine
.
amp_level
):
level
=
amp_level
):
out
=
engine
.
model
(
batch
)
loss_dict
=
engine
.
train_loss_func
(
out
,
batch
[
1
])
loss
=
loss_dict
[
"loss"
]
/
engine
.
update_freq
else
:
out
=
engine
.
model
(
batch
)
loss_dict
=
engine
.
train_loss_func
(
out
,
batch
[
1
])
# loss
loss
=
loss_dict
[
"loss"
]
/
engine
.
update_freq
# backward & step opt
if
engine
.
amp
:
scaled
=
engine
.
scaler
.
scale
(
loss
)
scaled
.
backward
()
if
(
iter_id
+
1
)
%
engine
.
update_freq
==
0
:
for
i
in
range
(
len
(
engine
.
optimizer
)):
engine
.
scaler
.
minimize
(
engine
.
optimizer
[
i
],
scaled
)
else
:
out
=
engine
.
model
(
batch
)
loss_dict
=
engine
.
train_loss_func
(
out
,
batch
[
1
])
loss
=
loss_dict
[
"loss"
]
/
engine
.
update_freq
loss
.
backward
()
if
(
iter_id
+
1
)
%
engine
.
update_freq
==
0
:
for
i
in
range
(
len
(
engine
.
optimizer
)):
...
...
ppcls/utils/model_saver.py
已删除
100644 → 0
浏览文件 @
502aead9
import
os
import
paddle
from
.
import
logger
def
_mkdir_if_not_exist
(
path
):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if
not
os
.
path
.
exists
(
path
):
try
:
os
.
makedirs
(
path
)
except
OSError
as
e
:
if
e
.
errno
==
errno
.
EEXIST
and
os
.
path
.
isdir
(
path
):
logger
.
warning
(
'be happy if some process has already created {}'
.
format
(
path
))
else
:
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
def
_extract_student_weights
(
all_params
,
student_prefix
=
"Student."
):
s_params
=
{
key
[
len
(
student_prefix
):]:
all_params
[
key
]
for
key
in
all_params
if
student_prefix
in
key
}
return
s_params
class
ModelSaver
(
object
):
def
__init__
(
self
,
engine
,
net_name
=
"model"
,
loss_name
=
"train_loss_func"
,
opt_name
=
"optimizer"
,
model_ema_name
=
"model_ema"
):
# net, loss, opt, model_ema, output_dir,
self
.
engine
=
engine
self
.
net_name
=
net_name
self
.
loss_name
=
loss_name
self
.
opt_name
=
opt_name
self
.
model_ema_name
=
model_ema_name
arch_name
=
engine
.
config
[
"Arch"
][
"name"
]
self
.
output_dir
=
os
.
path
.
join
(
engine
.
output_dir
,
arch_name
)
_mkdir_if_not_exist
(
self
.
output_dir
)
def
save
(
self
,
metric_info
,
prefix
=
'ppcls'
,
save_student_model
=
False
):
if
paddle
.
distributed
.
get_rank
()
!=
0
:
return
save_dir
=
os
.
path
.
join
(
self
.
output_dir
,
prefix
)
params_state_dict
=
getattr
(
self
.
engine
,
self
.
net_name
).
state_dict
()
loss
=
getattr
(
self
.
engine
,
self
.
loss_name
)
if
loss
is
not
None
:
loss_state_dict
=
loss
.
state_dict
()
keys_inter
=
set
(
params_state_dict
.
keys
())
&
set
(
loss_state_dict
.
keys
())
assert
len
(
keys_inter
)
==
0
,
\
f
"keys in model and loss state_dict must be unique, but got intersection
{
keys_inter
}
"
params_state_dict
.
update
(
loss_state_dict
)
if
save_student_model
:
s_params
=
_extract_student_weights
(
params_state_dict
)
if
len
(
s_params
)
>
0
:
paddle
.
save
(
s_params
,
save_dir
+
"_student.pdparams"
)
paddle
.
save
(
params_state_dict
,
save_dir
+
".pdparams"
)
model_ema
=
getattr
(
self
.
engine
,
self
.
model_ema_name
)
if
model_ema
is
not
None
:
paddle
.
save
(
model_ema
.
module
.
state_dict
(),
save_dir
+
".ema.pdparams"
)
optimizer
=
getattr
(
self
.
engine
,
self
.
opt_name
)
paddle
.
save
([
opt
.
state_dict
()
for
opt
in
optimizer
],
save_dir
+
".pdopt"
)
paddle
.
save
(
metric_info
,
save_dir
+
".pdstates"
)
logger
.
info
(
"Already save model in {}"
.
format
(
save_dir
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录