Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
08b6213b
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
接近 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
08b6213b
编写于
6月 30, 2021
作者:
H
Haoxin Ma
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix private function
上级
6d92417e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
79 additions
and
40 deletions
+79
-40
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+2
-3
deepspeech/utils/checkpoint.py
deepspeech/utils/checkpoint.py
+77
-37
未找到文件。
deepspeech/training/trainer.py
浏览文件 @
08b6213b
...
...
@@ -151,12 +151,11 @@ class Trainer():
resume training.
"""
scratch
=
None
infos
=
self
.
checkpoint
.
_load
_parameters
(
infos
=
self
.
checkpoint
.
load_latest
_parameters
(
self
.
model
,
self
.
optimizer
,
checkpoint_dir
=
self
.
checkpoint_dir
,
checkpoint_path
=
self
.
args
.
checkpoint_path
,
checkpoint_file
=
'checkpoint_latest'
)
checkpoint_path
=
self
.
args
.
checkpoint_path
)
if
infos
:
# restore from ckpt
self
.
iteration
=
infos
[
"step"
]
...
...
deepspeech/utils/checkpoint.py
浏览文件 @
08b6213b
...
...
@@ -38,23 +38,7 @@ class Checkpoint(object):
self
.
kbest_n
=
kbest_n
self
.
latest_n
=
latest_n
self
.
_save_all
=
(
kbest_n
==
-
1
)
def
_should_save_best
(
self
,
metric
:
float
)
->
bool
:
if
not
self
.
_best_full
():
return
True
# already full
worst_record_path
=
max
(
self
.
best_records
,
key
=
self
.
best_records
.
get
)
# worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
worst_metric
=
self
.
best_records
[
worst_record_path
]
return
metric
<
worst_metric
def
_best_full
(
self
):
return
(
not
self
.
_save_all
)
and
len
(
self
.
best_records
)
==
self
.
kbest_n
def
_latest_full
(
self
):
return
len
(
self
.
latest_records
)
==
self
.
latest_n
def
add_checkpoint
(
self
,
checkpoint_dir
,
tag_or_iteration
,
...
...
@@ -64,7 +48,7 @@ class Checkpoint(object):
metric_type
=
"val_loss"
):
if
(
metric_type
not
in
infos
.
keys
()):
self
.
_save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
optimizer
,
infos
)
return
#save best
...
...
@@ -73,15 +57,71 @@ class Checkpoint(object):
infos
[
metric_type
],
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
#save latest
self
.
_save_latest_checkpoint_and_update
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
self
.
_save_latest_checkpoint_and_update
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
if
isinstance
(
tag_or_iteration
,
int
):
self
.
_save_checkpoint_record
(
checkpoint_dir
,
tag_or_iteration
)
def load_latest_parameters(self,
                           model,
                           optimizer=None,
                           checkpoint_dir=None,
                           checkpoint_path=None):
    """Load the most recent model checkpoint from disk.

    Args:
        model (Layer): model to load parameters.
        optimizer (Optimizer, optional): optimizer to load states if needed.
            Defaults to None.
        checkpoint_dir (str, optional): the directory where checkpoint is saved.
        checkpoint_path (str, optional): if specified, load the checkpoint
            stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
            be ignored. Defaults to None.
    Returns:
        configs (dict): epoch or step, lr and other meta info should be saved.
    """
    # Delegate to the shared loader with the "latest" record file.
    record_file = "checkpoint_latest"
    return self._load_parameters(model, optimizer, checkpoint_dir,
                                 checkpoint_path, record_file)
def load_best_parameters(self,
                         model,
                         optimizer=None,
                         checkpoint_dir=None,
                         checkpoint_path=None):
    """Load the best (by validation metric) model checkpoint from disk.

    Args:
        model (Layer): model to load parameters.
        optimizer (Optimizer, optional): optimizer to load states if needed.
            Defaults to None.
        checkpoint_dir (str, optional): the directory where checkpoint is saved.
        checkpoint_path (str, optional): if specified, load the checkpoint
            stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
            be ignored. Defaults to None.
    Returns:
        configs (dict): epoch or step, lr and other meta info should be saved.
    """
    # The "checkpoint_best" record file is maintained by the k-best bookkeeping.
    return self._load_parameters(model, optimizer, checkpoint_dir,
                                 checkpoint_path, "checkpoint_best")
def
_should_save_best
(
self
,
metric
:
float
)
->
bool
:
if
not
self
.
_best_full
():
return
True
# already full
worst_record_path
=
max
(
self
.
best_records
,
key
=
self
.
best_records
.
get
)
# worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
worst_metric
=
self
.
best_records
[
worst_record_path
]
return
metric
<
worst_metric
def
_best_full
(
self
):
return
(
not
self
.
_save_all
)
and
len
(
self
.
best_records
)
==
self
.
kbest_n
def
_latest_full
(
self
):
return
len
(
self
.
latest_records
)
==
self
.
latest_n
def
_save_best_checkpoint_and_update
(
self
,
metric
,
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
):
tag_or_iteration
,
model
,
optimizer
,
infos
):
# remove the worst
if
self
.
_best_full
():
worst_record_path
=
max
(
self
.
best_records
,
...
...
@@ -93,8 +133,8 @@ class Checkpoint(object):
self
.
_del_checkpoint
(
checkpoint_dir
,
worst_record_path
)
# add the new one
self
.
_save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
self
.
_save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
self
.
best_records
[
tag_or_iteration
]
=
metric
def
_save_latest_checkpoint_and_update
(
...
...
@@ -108,8 +148,8 @@ class Checkpoint(object):
self
.
_del_checkpoint
(
checkpoint_dir
,
to_del_fn
)
self
.
latest_records
.
append
(
tag_or_iteration
)
self
.
_save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
self
.
_save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
def
_del_checkpoint
(
self
,
checkpoint_dir
,
tag_or_iteration
):
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
...
...
@@ -153,13 +193,12 @@ class Checkpoint(object):
for
i
in
self
.
latest_records
:
handle
.
write
(
"model_checkpoint_path:{}
\n
"
.
format
(
i
))
def
_load_parameters
(
self
,
model
,
optimizer
=
None
,
checkpoint_dir
=
None
,
checkpoint_path
=
None
,
checkpoint_file
=
None
):
model
,
optimizer
=
None
,
checkpoint_dir
=
None
,
checkpoint_path
=
None
,
checkpoint_file
=
None
):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
...
...
@@ -209,13 +248,14 @@ class Checkpoint(object):
configs
=
json
.
load
(
fin
)
return
configs
@
mp_tools
.
rank_zero_only
def
_save_parameters
(
self
,
checkpoint_dir
:
str
,
tag_or_iteration
:
Union
[
int
,
str
],
model
:
paddle
.
nn
.
Layer
,
optimizer
:
Optimizer
=
None
,
infos
:
dict
=
None
):
checkpoint_dir
:
str
,
tag_or_iteration
:
Union
[
int
,
str
],
model
:
paddle
.
nn
.
Layer
,
optimizer
:
Optimizer
=
None
,
infos
:
dict
=
None
):
"""Checkpoint the latest trained model parameters.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录