Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
16210c05
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
1 年多 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
16210c05
编写于
6月 25, 2021
作者:
H
Haoxin Ma
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug
上级
91e70a28
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
63 additions
and
60 deletions
+63
-60
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+1
-1
deepspeech/utils/checkpoint.py
deepspeech/utils/checkpoint.py
+62
-59
未找到文件。
deepspeech/training/trainer.py
浏览文件 @
16210c05
...
...
@@ -64,7 +64,7 @@ class Trainer():
The parsed command line arguments.
Examples
--------
>>> def p(config, args):
>>> def
main_s
p(config, args):
>>> exp = Trainer(config, args)
>>> exp.setup()
>>> exp.run()
...
...
deepspeech/utils/checkpoint.py
浏览文件 @
16210c05
...
...
@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
glob
import
json
import
os
import
re
from
pathlib
import
Path
from
typing
import
Union
import
paddle
...
...
@@ -22,25 +24,21 @@ from paddle.optimizer import Optimizer
from
deepspeech.utils
import
mp_tools
from
deepspeech.utils.log
import
Log
import
glob
# import operator
from
pathlib
import
Path
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
"Checkpoint"
]
class
Checkpoint
(
object
):
def __init__(self, kbest_n: int=5, latest_n: int=1):
    """Track the k best (lowest-metric) and the n most recent checkpoints.

    Args:
        kbest_n: how many best checkpoints to retain; ``-1`` means save
            everything (the best set is then never considered full).
        latest_n: how many most recent checkpoints to retain.
    """
    # NOTE(review): the original annotated this attribute as
    # `Mapping[Path, float]`, but `Mapping` is not imported in the visible
    # import block. Annotations on attribute targets ARE evaluated at
    # runtime (PEP 526), so constructing a Checkpoint would raise
    # NameError. Keep the type as a comment instead.
    self.best_records = {}  # checkpoint tag -> metric value (lower is better)
    self.latest_records = []  # FIFO of the most recent checkpoint tags
    self.kbest_n = kbest_n
    self.latest_n = latest_n
    # kbest_n == -1 selects "save all" mode: best_full() is then never true.
    self._save_all = (kbest_n == -1)
def
should_save_best
(
self
,
metric
:
float
)
->
bool
:
if
not
self
.
best_full
():
return
True
...
...
@@ -53,68 +51,72 @@ class Checkpoint(object):
def best_full(self):
    """Return whether the k-best record set has reached its capacity.

    Never true in save-all mode (``kbest_n == -1``).
    """
    if self._save_all:
        return False
    return len(self.best_records) == self.kbest_n
def latest_full(self):
    """Return whether the latest-checkpoint window has reached its capacity."""
    held = len(self.latest_records)
    return held == self.latest_n
def add_checkpoint(self,
                   checkpoint_dir,
                   tag_or_iteration,
                   model,
                   optimizer,
                   infos,
                   metric_type="val_loss"):
    """Save a checkpoint and update the k-best / latest bookkeeping.

    When ``infos`` does not carry the tracked metric, the checkpoint is
    written unconditionally and no record keeping happens. Otherwise the
    checkpoint is considered for the k-best set, always rotated into the
    "latest" window, and — for integer tags — the on-disk record files
    are refreshed.
    """
    # No metric available: plain save, nothing to rank or rotate.
    if metric_type not in infos.keys():
        self.save_parameters(checkpoint_dir, tag_or_iteration, model,
                             optimizer, infos)
        return

    metric = infos[metric_type]
    # Candidate for the k-best set.
    if self.should_save_best(metric):
        self.save_best_checkpoint_and_update(metric, checkpoint_dir,
                                             tag_or_iteration, model,
                                             optimizer, infos)
    # Always rotate the "latest" window.
    self.save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
                                           model, optimizer, infos)

    if isinstance(tag_or_iteration, int):
        self.save_checkpoint_record(checkpoint_dir, tag_or_iteration)
def save_best_checkpoint_and_update(self, metric, checkpoint_dir,
                                    tag_or_iteration, model, optimizer,
                                    infos):
    """Insert a checkpoint into the k-best set, evicting the worst if full.

    The evicted checkpoint's files are only deleted from disk when its tag
    is not also referenced by the latest-checkpoint window.
    """
    if self.best_full():
        # Highest metric value is the worst record (lower is better).
        worst_record_path = max(self.best_records, key=self.best_records.get)
        self.best_records.pop(worst_record_path)
        if worst_record_path not in self.latest_records:
            logger.info(
                "remove the worst checkpoint: {}".format(worst_record_path))
            self.del_checkpoint(checkpoint_dir, worst_record_path)

    # Persist the new checkpoint and record its metric.
    self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
                         infos)
    self.best_records[tag_or_iteration] = metric
def save_latest_checkpoint_and_update(self, checkpoint_dir, tag_or_iteration,
                                      model, optimizer, infos):
    """Append a checkpoint to the latest window, evicting the oldest if full.

    The evicted checkpoint's files are only deleted from disk when its tag
    is not also held in the k-best record set.
    """
    if self.latest_full():
        oldest_tag = self.latest_records.pop(0)
        if oldest_tag not in self.best_records.keys():
            logger.info("remove the latest checkpoint: {}".format(oldest_tag))
            self.del_checkpoint(checkpoint_dir, oldest_tag)

    self.latest_records.append(tag_or_iteration)
    self.save_parameters(checkpoint_dir, tag_or_iteration, model, optimizer,
                         infos)
def del_checkpoint(self, checkpoint_dir, tag_or_iteration):
    """Delete every on-disk file belonging to one checkpoint tag."""
    checkpoint_path = os.path.join(checkpoint_dir,
                                   "{}".format(tag_or_iteration))
    # A checkpoint is a family of files named "<tag>.<ext>"
    # (e.g. the ".pdparams" weights file).
    for filename in glob.glob(checkpoint_path + ".*"):
        os.remove(filename)
        logger.info("delete file: {}".format(filename))
def
load_checkpoint_idx
(
self
,
checkpoint_record
:
str
)
->
int
:
"""Get the iteration number corresponding to the latest saved checkpoint.
...
...
@@ -131,7 +133,6 @@ class Checkpoint(object):
latest_checkpoint
=
handle
.
readlines
()[
-
1
].
strip
()
iteration
=
int
(
latest_checkpoint
.
split
(
":"
)[
-
1
])
return
iteration
def
save_checkpoint_record
(
self
,
checkpoint_dir
:
str
,
iteration
:
int
):
"""Save the iteration number of the latest model to be checkpoint record.
...
...
@@ -141,9 +142,10 @@ class Checkpoint(object):
Returns:
None
"""
checkpoint_record_latest
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint_latest"
)
checkpoint_record_latest
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint_latest"
)
checkpoint_record_best
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint_best"
)
with
open
(
checkpoint_record_best
,
"w"
)
as
handle
:
for
i
in
self
.
best_records
.
keys
():
handle
.
write
(
"model_checkpoint_path:{}
\n
"
.
format
(
i
))
...
...
@@ -151,11 +153,11 @@ class Checkpoint(object):
for
i
in
self
.
latest_records
:
handle
.
write
(
"model_checkpoint_path:{}
\n
"
.
format
(
i
))
def
load_last_parameters
(
self
,
model
,
optimizer
=
None
,
checkpoint_dir
=
None
,
checkpoint_path
=
None
):
def
load_last_parameters
(
self
,
model
,
optimizer
=
None
,
checkpoint_dir
=
None
,
checkpoint_path
=
None
):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
...
...
@@ -173,11 +175,13 @@ class Checkpoint(object):
if
checkpoint_path
is
not
None
:
tag
=
os
.
path
.
basename
(
checkpoint_path
).
split
(
":"
)[
-
1
]
elif
checkpoint_dir
is
not
None
:
checkpoint_record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint_latest"
)
checkpoint_record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint_latest"
)
iteration
=
self
.
load_checkpoint_idx
(
checkpoint_record
)
if
iteration
==
-
1
:
return
configs
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"{}"
.
format
(
iteration
))
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"{}"
.
format
(
iteration
))
else
:
raise
ValueError
(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
...
...
@@ -203,11 +207,11 @@ class Checkpoint(object):
configs
=
json
.
load
(
fin
)
return
configs
def
load_best_parameters
(
self
,
model
,
optimizer
=
None
,
checkpoint_dir
=
None
,
checkpoint_path
=
None
):
def
load_best_parameters
(
self
,
model
,
optimizer
=
None
,
checkpoint_dir
=
None
,
checkpoint_path
=
None
):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
...
...
@@ -229,7 +233,8 @@ class Checkpoint(object):
iteration
=
self
.
load_checkpoint_idx
(
checkpoint_record
)
if
iteration
==
-
1
:
return
configs
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"{}"
.
format
(
iteration
))
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"{}"
.
format
(
iteration
))
else
:
raise
ValueError
(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
...
...
@@ -255,10 +260,9 @@ class Checkpoint(object):
configs
=
json
.
load
(
fin
)
return
configs
@
mp_tools
.
rank_zero_only
def
save_parameters
(
self
,
checkpoint_dir
:
str
,
def
save_parameters
(
self
,
checkpoint_dir
:
str
,
tag_or_iteration
:
Union
[
int
,
str
],
model
:
paddle
.
nn
.
Layer
,
optimizer
:
Optimizer
=
None
,
...
...
@@ -275,7 +279,7 @@ class Checkpoint(object):
None
"""
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"{}"
.
format
(
tag_or_iteration
))
"{}"
.
format
(
tag_or_iteration
))
model_dict
=
model
.
state_dict
()
params_path
=
checkpoint_path
+
".pdparams"
...
...
@@ -293,4 +297,3 @@ class Checkpoint(object):
with
open
(
info_path
,
'w'
)
as
fout
:
data
=
json
.
dumps
(
infos
)
fout
.
write
(
data
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录