Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
6d92417e
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
1 年多 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6d92417e
编写于
6月 29, 2021
作者:
H
Haoxin Ma
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize the function
上级
16210c05
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
32 addition
and
82 deletion
+32
-82
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+3
-2
deepspeech/utils/checkpoint.py
deepspeech/utils/checkpoint.py
+29
-80
未找到文件。
deepspeech/training/trainer.py
浏览文件 @
6d92417e
...
...
@@ -151,11 +151,12 @@ class Trainer():
resume training.
"""
scratch
=
None
infos
=
self
.
checkpoint
.
load_last
_parameters
(
infos
=
self
.
checkpoint
.
_load
_parameters
(
self
.
model
,
self
.
optimizer
,
checkpoint_dir
=
self
.
checkpoint_dir
,
checkpoint_path
=
self
.
args
.
checkpoint_path
)
checkpoint_path
=
self
.
args
.
checkpoint_path
,
checkpoint_file
=
'checkpoint_latest'
)
if
infos
:
# restore from ckpt
self
.
iteration
=
infos
[
"step"
]
...
...
deepspeech/utils/checkpoint.py
浏览文件 @
6d92417e
...
...
@@ -39,8 +39,8 @@ class Checkpoint(object):
self
.
latest_n
=
latest_n
self
.
_save_all
=
(
kbest_n
==
-
1
)
def
should_save_best
(
self
,
metric
:
float
)
->
bool
:
if
not
self
.
best_full
():
def
_
should_save_best
(
self
,
metric
:
float
)
->
bool
:
if
not
self
.
_
best_full
():
return
True
# already full
...
...
@@ -49,10 +49,10 @@ class Checkpoint(object):
worst_metric
=
self
.
best_records
[
worst_record_path
]
return
metric
<
worst_metric
def
best_full
(
self
):
def
_
best_full
(
self
):
return
(
not
self
.
_save_all
)
and
len
(
self
.
best_records
)
==
self
.
kbest_n
def
latest_full
(
self
):
def
_
latest_full
(
self
):
return
len
(
self
.
latest_records
)
==
self
.
latest_n
def
add_checkpoint
(
self
,
...
...
@@ -63,62 +63,62 @@ class Checkpoint(object):
infos
,
metric_type
=
"val_loss"
):
if
(
metric_type
not
in
infos
.
keys
()):
self
.
save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
self
.
_
save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
return
#save best
if
self
.
should_save_best
(
infos
[
metric_type
]):
self
.
save_best_checkpoint_and_update
(
if
self
.
_
should_save_best
(
infos
[
metric_type
]):
self
.
_
save_best_checkpoint_and_update
(
infos
[
metric_type
],
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
#save latest
self
.
save_latest_checkpoint_and_update
(
checkpoint_dir
,
tag_or_iteration
,
self
.
_
save_latest_checkpoint_and_update
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
if
isinstance
(
tag_or_iteration
,
int
):
self
.
save_checkpoint_record
(
checkpoint_dir
,
tag_or_iteration
)
self
.
_
save_checkpoint_record
(
checkpoint_dir
,
tag_or_iteration
)
def
save_best_checkpoint_and_update
(
self
,
metric
,
checkpoint_dir
,
def
_
save_best_checkpoint_and_update
(
self
,
metric
,
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
):
# remove the worst
if
self
.
best_full
():
if
self
.
_
best_full
():
worst_record_path
=
max
(
self
.
best_records
,
key
=
self
.
best_records
.
get
)
self
.
best_records
.
pop
(
worst_record_path
)
if
(
worst_record_path
not
in
self
.
latest_records
):
logger
.
info
(
"remove the worst checkpoint: {}"
.
format
(
worst_record_path
))
self
.
del_checkpoint
(
checkpoint_dir
,
worst_record_path
)
self
.
_
del_checkpoint
(
checkpoint_dir
,
worst_record_path
)
# add the new one
self
.
save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
self
.
_
save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
self
.
best_records
[
tag_or_iteration
]
=
metric
def
save_latest_checkpoint_and_update
(
def
_
save_latest_checkpoint_and_update
(
self
,
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
):
# remove the old
if
self
.
latest_full
():
if
self
.
_
latest_full
():
to_del_fn
=
self
.
latest_records
.
pop
(
0
)
if
(
to_del_fn
not
in
self
.
best_records
.
keys
()):
logger
.
info
(
"remove the latest checkpoint: {}"
.
format
(
to_del_fn
))
self
.
del_checkpoint
(
checkpoint_dir
,
to_del_fn
)
self
.
_
del_checkpoint
(
checkpoint_dir
,
to_del_fn
)
self
.
latest_records
.
append
(
tag_or_iteration
)
self
.
save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
self
.
_
save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
def
del_checkpoint
(
self
,
checkpoint_dir
,
tag_or_iteration
):
def
_
del_checkpoint
(
self
,
checkpoint_dir
,
tag_or_iteration
):
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"{}"
.
format
(
tag_or_iteration
))
for
filename
in
glob
.
glob
(
checkpoint_path
+
".*"
):
os
.
remove
(
filename
)
logger
.
info
(
"delete file: {}"
.
format
(
filename
))
def
load_checkpoint_idx
(
self
,
checkpoint_record
:
str
)
->
int
:
def
_
load_checkpoint_idx
(
self
,
checkpoint_record
:
str
)
->
int
:
"""Get the iteration number corresponding to the latest saved checkpoint.
Args:
checkpoint_path (str): the saved path of checkpoint.
...
...
@@ -134,7 +134,7 @@ class Checkpoint(object):
iteration
=
int
(
latest_checkpoint
.
split
(
":"
)[
-
1
])
return
iteration
def
save_checkpoint_record
(
self
,
checkpoint_dir
:
str
,
iteration
:
int
):
def
_
save_checkpoint_record
(
self
,
checkpoint_dir
:
str
,
iteration
:
int
):
"""Save the iteration number of the latest model to be checkpoint record.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
...
...
@@ -153,65 +153,13 @@ class Checkpoint(object):
for
i
in
self
.
latest_records
:
handle
.
write
(
"model_checkpoint_path:{}
\n
"
.
format
(
i
))
def
load_last_parameters
(
self
,
model
,
optimizer
=
None
,
checkpoint_dir
=
None
,
checkpoint_path
=
None
):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
configs
=
{}
if
checkpoint_path
is
not
None
:
tag
=
os
.
path
.
basename
(
checkpoint_path
).
split
(
":"
)[
-
1
]
elif
checkpoint_dir
is
not
None
:
checkpoint_record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint_latest"
)
iteration
=
self
.
load_checkpoint_idx
(
checkpoint_record
)
if
iteration
==
-
1
:
return
configs
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"{}"
.
format
(
iteration
))
else
:
raise
ValueError
(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
)
rank
=
dist
.
get_rank
()
params_path
=
checkpoint_path
+
".pdparams"
model_dict
=
paddle
.
load
(
params_path
)
model
.
set_state_dict
(
model_dict
)
logger
.
info
(
"Rank {}: loaded model from {}"
.
format
(
rank
,
params_path
))
optimizer_path
=
checkpoint_path
+
".pdopt"
if
optimizer
and
os
.
path
.
isfile
(
optimizer_path
):
optimizer_dict
=
paddle
.
load
(
optimizer_path
)
optimizer
.
set_state_dict
(
optimizer_dict
)
logger
.
info
(
"Rank {}: loaded optimizer state from {}"
.
format
(
rank
,
optimizer_path
))
info_path
=
re
.
sub
(
'.pdparams$'
,
'.json'
,
params_path
)
if
os
.
path
.
exists
(
info_path
):
with
open
(
info_path
,
'r'
)
as
fin
:
configs
=
json
.
load
(
fin
)
return
configs
def
load_best
_parameters
(
self
,
def
_load
_parameters
(
self
,
model
,
optimizer
=
None
,
checkpoint_dir
=
None
,
checkpoint_path
=
None
):
checkpoint_path
=
None
,
checkpoint_file
=
None
):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
...
...
@@ -221,6 +169,7 @@ class Checkpoint(object):
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
checkpoint_file "checkpoint_latest" or "checkpoint_best"
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
...
...
@@ -228,16 +177,16 @@ class Checkpoint(object):
if
checkpoint_path
is
not
None
:
tag
=
os
.
path
.
basename
(
checkpoint_path
).
split
(
":"
)[
-
1
]
elif
checkpoint_dir
is
not
None
:
checkpoint_record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint_best"
)
iteration
=
self
.
load_checkpoint_idx
(
checkpoint_record
)
elif
checkpoint_dir
is
not
None
and
checkpoint_file
is
not
None
:
checkpoint_record
=
os
.
path
.
join
(
checkpoint_dir
,
checkpoint_file
)
iteration
=
self
.
_
load_checkpoint_idx
(
checkpoint_record
)
if
iteration
==
-
1
:
return
configs
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"{}"
.
format
(
iteration
))
else
:
raise
ValueError
(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
"At least one of 'checkpoint_dir' and 'checkpoint_
file' and 'checkpoint_
path' should be specified!"
)
rank
=
dist
.
get_rank
()
...
...
@@ -261,7 +210,7 @@ class Checkpoint(object):
return
configs
@
mp_tools
.
rank_zero_only
def
save_parameters
(
self
,
def
_
save_parameters
(
self
,
checkpoint_dir
:
str
,
tag_or_iteration
:
Union
[
int
,
str
],
model
:
paddle
.
nn
.
Layer
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录