PaddlePaddle / DeepSpeech

Commit 91e70a28
Authored on Jun 25, 2021 by Haoxin Ma

multi gpus

Parent: 8af2eb07
Showing 3 changed files with 105 additions and 59 deletions:

    deepspeech/training/trainer.py          +8   -10
    deepspeech/utils/checkpoint.py          +96  -48
    examples/tiny/s0/conf/deepspeech2.yaml  +1   -1

deepspeech/training/trainer.py

@@ -18,8 +18,8 @@ import paddle
 from paddle import distributed as dist
 from tensorboardX import SummaryWriter
 
-from deepspeech.utils.checkpoint import KBestCheckpoint
 from deepspeech.utils import mp_tools
+from deepspeech.utils.checkpoint import Checkpoint
 from deepspeech.utils.log import Log
 
 __all__ = ["Trainer"]

@@ -64,7 +64,7 @@ class Trainer():
         The parsed command line arguments.
     Examples
     --------
-    >>> def main_sp(config, args):
+    >>> def p(config, args):
     >>> exp = Trainer(config, args)
     >>> exp.setup()
     >>> exp.run()

@@ -140,11 +140,8 @@ class Trainer():
             "lr": self.optimizer.get_lr()
         })
         self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration
-                                        if tag is None else tag,
-                                        self.model, self.optimizer, infos)
-        # checkpoint.save_parameters(self.checkpoint_dir, self.iteration
-        #                            if tag is None else tag, self.model,
-        #                            self.optimizer, infos)
+                                       if tag is None else tag,
+                                       self.model, self.optimizer, infos)
 
     def resume_or_scratch(self):
         """Resume from latest checkpoint at checkpoints in the output

@@ -154,7 +151,7 @@ class Trainer():
             resume training.
         """
         scratch = None
-        infos = self.checkpoint.load_parameters(
+        infos = self.checkpoint.load_last_parameters(
             self.model,
             self.optimizer,
             checkpoint_dir=self.checkpoint_dir,

@@ -266,8 +263,9 @@ class Trainer():
         self.checkpoint_dir = checkpoint_dir
 
-        self.checkpoint = KBestCheckpoint(max_size=self.config.training.checkpoint.kbest_n,
-                                          last_size=self.config.training.checkpoint.latest_n)
+        self.checkpoint = Checkpoint(
+            kbest_n=self.config.training.checkpoint.kbest_n,
+            latest_n=self.config.training.checkpoint.latest_n)
 
     @mp_tools.rank_zero_only
     def destory(self):
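
Note (not part of the diff): after this change the trainer builds the renamed Checkpoint object with kbest_n/latest_n instead of KBestCheckpoint(max_size, last_size), and resumes through load_last_parameters(). A minimal sketch of that wiring, with a placeholder model, optimizer and checkpoint directory:

import paddle
from deepspeech.utils.checkpoint import Checkpoint

model = paddle.nn.Linear(10, 10)  # placeholder model
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

# Renamed API: Checkpoint(kbest_n=..., latest_n=...) replaces
# KBestCheckpoint(max_size=..., last_size=...).
checkpoint = Checkpoint(kbest_n=5, latest_n=1)

# Resume path: load_parameters() is now load_last_parameters().
infos = checkpoint.load_last_parameters(
    model, optimizer, checkpoint_dir="exp/checkpoints")

# Save path: add_checkpoint() keeps the k best checkpoints (by val_loss)
# plus the latest ones, as the trainer does when saving.
checkpoint.add_checkpoint("exp/checkpoints", 100, model, optimizer,
                          {"val_loss": 0.42, "epoch": 1,
                           "lr": optimizer.get_lr()})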

deepspeech/utils/checkpoint.py

@@ -24,20 +24,22 @@ from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Log
 
 import glob
 # import operator
 from pathlib import Path
 
 logger = Log(__name__).getlog()
 
-__all__ = ["load_parameters", "save_parameters"]
+__all__ = ["Checkpoint"]
 
 
-class KBestCheckpoint(object):
-    def __init__(self, max_size: int=5, last_size: int=1):
+class Checkpoint(object):
+    def __init__(self,
+                 kbest_n: int=5,
+                 latest_n: int=1):
         self.best_records: Mapping[Path, float] = {}
-        self.last_records = []
-        self.max_size = max_size
-        self.last_size = last_size
-        self._save_all = (max_size == -1)
+        self.latest_records = []
+        self.kbest_n = kbest_n
+        self.latest_n = latest_n
+        self._save_all = (kbest_n == -1)
 
     def should_save_best(self, metric: float) -> bool:
         if not self.best_full():

@@ -45,36 +47,36 @@ class KBestCheckpoint(object):
         # already full
         worst_record_path = max(self.best_records, key=self.best_records.get)
         # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
         worst_metric = self.best_records[worst_record_path]
         return metric < worst_metric
 
     def best_full(self):
-        return (not self._save_all) and len(self.best_records) == self.max_size
+        return (not self._save_all) and len(self.best_records) == self.kbest_n
 
-    def last_full(self):
-        return len(self.last_records) == self.last_size
+    def latest_full(self):
+        return len(self.latest_records) == self.latest_n
 
-    def add_checkpoint(self, checkpoint_dir, tag_or_iteration, model,
-                       optimizer, infos):
-        if ("val_loss" not in infos.keys()):
+    def add_checkpoint(self, checkpoint_dir, tag_or_iteration, model,
+                       optimizer, infos, metric_type="val_loss"):
+        if (metric_type not in infos.keys()):
             self.save_parameters(checkpoint_dir, tag_or_iteration, model,
                                  optimizer, infos)
             return
 
         #save best
-        if self.should_save_best(infos["val_loss"]):
-            self.save_checkpoint_and_update(infos["val_loss"],
+        if self.should_save_best(infos[metric_type]):
+            self.save_best_checkpoint_and_update(infos[metric_type],
                 checkpoint_dir, tag_or_iteration, model, optimizer, infos)
 
-        #save last
-        self.save_last_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
+        #save latest
+        self.save_latest_checkpoint_and_update(checkpoint_dir, tag_or_iteration,
             model, optimizer, infos)
 
         if isinstance(tag_or_iteration, int):
-            self._save_record(checkpoint_dir, tag_or_iteration)
+            self.save_checkpoint_record(checkpoint_dir, tag_or_iteration)
 
-    def save_checkpoint_and_update(self, metric,
+    def save_best_checkpoint_and_update(self, metric,
             checkpoint_dir, tag_or_iteration, model, optimizer, infos):
         # remove the worst
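
The policy behind add_checkpoint() above can be read as: keep the kbest_n records with the lowest metric (val_loss by default, lower is better) plus the latest_n most recent records, and only delete a checkpoint's files once it belongs to neither set. A standalone toy illustration of that bookkeeping (not code from the repository):

kbest_n, latest_n = 2, 1
best_records = {}    # tag -> metric, mirrors self.best_records
latest_records = []  # FIFO of tags, mirrors self.latest_records

def add(tag, metric):
    # best set: if full, evict the current worst when the new metric is lower
    if len(best_records) < kbest_n or metric < max(best_records.values()):
        if len(best_records) == kbest_n:
            worst = max(best_records, key=best_records.get)
            best_records.pop(worst)
            if worst not in latest_records:
                print("delete files for", worst)
        best_records[tag] = metric
    # latest set: sliding window of the most recent tags
    if len(latest_records) == latest_n:
        old = latest_records.pop(0)
        if old not in best_records:
            print("delete files for", old)
    latest_records.append(tag)

for step, loss in enumerate([1.0, 0.8, 0.9, 0.7], start=1):
    add(step, loss)

print(best_records)    # {2: 0.8, 4: 0.7} -> the two lowest val_loss values kept
print(latest_records)  # [4]              -> the most recent step kept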

@@ -82,9 +84,8 @@ class KBestCheckpoint(object):
         worst_record_path = max(self.best_records, key=self.best_records.get)
         self.best_records.pop(worst_record_path)
-        if (worst_record_path not in self.last_records):
-            print('----to remove (best)----')
-            print(worst_record_path)
+        if (worst_record_path not in self.latest_records):
+            logger.info("remove the worst checkpoint: {}".format(worst_record_path))
             self.del_checkpoint(checkpoint_dir, worst_record_path)
 
         # add the new one

@@ -92,22 +93,18 @@ class KBestCheckpoint(object):
                              model, optimizer, infos)
         self.best_records[tag_or_iteration] = metric
 
-    def save_last_checkpoint_and_update(self, checkpoint_dir, tag_or_iteration,
+    def save_latest_checkpoint_and_update(self, checkpoint_dir, tag_or_iteration,
             model, optimizer, infos):
         # remove the old
-        if self.last_full():
-            to_del_fn = self.last_records.pop(0)
+        if self.latest_full():
+            to_del_fn = self.latest_records.pop(0)
             if (to_del_fn not in self.best_records.keys()):
-                print('----to remove (last)----')
-                print(to_del_fn)
+                logger.info("remove the latest checkpoint: {}".format(to_del_fn))
                 self.del_checkpoint(checkpoint_dir, to_del_fn)
-        self.last_records.append(tag_or_iteration)
+        self.latest_records.append(tag_or_iteration)
 
         self.save_parameters(checkpoint_dir, tag_or_iteration, model,
                              optimizer, infos)
-        # with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as handle:
-        #     for iteration in self.best_records
-        #         handle.write("model_checkpoint_path:{}\n".format(iteration))
 
     def del_checkpoint(self, checkpoint_dir, tag_or_iteration):

@@ -115,18 +112,17 @@ class KBestCheckpoint(object):
                                        "{}".format(tag_or_iteration))
         for filename in glob.glob(checkpoint_path + ".*"):
             os.remove(filename)
-            print("delete file: " + filename)
+            logger.info("delete file: {}".format(filename))
 
-    def _load_latest_checkpoint(self, checkpoint_dir: str) -> int:
+    def load_checkpoint_idx(self, checkpoint_record: str) -> int:
         """Get the iteration number corresponding to the latest saved checkpoint.
         Args:
-            checkpoint_dir (str): the directory where checkpoint is saved.
+            checkpoint_path (str): the saved path of checkpoint.
         Returns:
             int: the latest iteration number. -1 for no checkpoint to load.
         """
-        checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_last")
         if not os.path.isfile(checkpoint_record):
             return -1

@@ -135,9 +131,9 @@ class KBestCheckpoint(object):
             latest_checkpoint = handle.readlines()[-1].strip()
             iteration = int(latest_checkpoint.split(":")[-1])
         return iteration
 
-    def _save_record(self, checkpoint_dir: str, iteration: int):
+    def save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
         """Save the iteration number of the latest model to be checkpoint record.
         Args:
             checkpoint_dir (str): the directory where checkpoint is saved.

@@ -145,24 +141,22 @@ class KBestCheckpoint(object):
         Returns:
             None
         """
-        checkpoint_record_last = os.path.join(checkpoint_dir, "checkpoint_last")
+        checkpoint_record_latest = os.path.join(checkpoint_dir,
+                                                "checkpoint_latest")
         checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best")
 
-        # Update the latest checkpoint index.
-        # with open(checkpoint_record, "a+") as handle:
-        #     handle.write("model_checkpoint_path:{}\n".format(iteration))
         with open(checkpoint_record_best, "w") as handle:
             for i in self.best_records.keys():
                 handle.write("model_checkpoint_path:{}\n".format(i))
-        with open(checkpoint_record_last, "w") as handle:
-            for i in self.last_records:
+        with open(checkpoint_record_latest, "w") as handle:
+            for i in self.latest_records:
                 handle.write("model_checkpoint_path:{}\n".format(i))
 
-    def load_parameters(self, model,
+    def load_last_parameters(self, model,
                         optimizer=None,
                         checkpoint_dir=None,
                         checkpoint_path=None):
-        """Load a specific model checkpoint from disk.
+        """Load a last model checkpoint from disk.
         Args:
             model (Layer): model to load parameters.
             optimizer (Optimizer, optional): optimizer to load states if needed.
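
For reference, the record files written by save_checkpoint_record() above are plain text with one "model_checkpoint_path:<tag>" line per kept checkpoint, and load_checkpoint_idx() takes the last line and parses the text after the colon. A small round-trip sketch (the directory and tags are made up):

import os
import tempfile

checkpoint_dir = tempfile.mkdtemp()
record = os.path.join(checkpoint_dir, "checkpoint_latest")

# What save_checkpoint_record() writes for latest_records = [98, 99, 100].
with open(record, "w") as handle:
    for i in [98, 99, 100]:
        handle.write("model_checkpoint_path:{}\n".format(i))

# Same parsing as load_checkpoint_idx(): last line, text after the colon.
with open(record, "r") as handle:
    latest_checkpoint = handle.readlines()[-1].strip()
iteration = int(latest_checkpoint.split(":")[-1])
print(iteration)  # 100 -> the loaders then read "<checkpoint_dir>/100.pdparams"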

@@ -179,7 +173,8 @@ class KBestCheckpoint(object):
         if checkpoint_path is not None:
             tag = os.path.basename(checkpoint_path).split(":")[-1]
         elif checkpoint_dir is not None:
-            iteration = self._load_latest_checkpoint(checkpoint_dir)
+            checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_latest")
+            iteration = self.load_checkpoint_idx(checkpoint_record)
             if iteration == -1:
                 return configs
             checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))

@@ -209,6 +204,59 @@ class KBestCheckpoint(object):
 
         return configs
 
+    def load_best_parameters(self, model,
+                             optimizer=None,
+                             checkpoint_dir=None,
+                             checkpoint_path=None):
+        """Load a last model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where checkpoint is saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
+        Returns:
+            configs (dict): epoch or step, lr and other meta info should be saved.
+        """
+        configs = {}
+
+        if checkpoint_path is not None:
+            tag = os.path.basename(checkpoint_path).split(":")[-1]
+        elif checkpoint_dir is not None:
+            checkpoint_record = os.path.join(checkpoint_dir, "checkpoint_best")
+            iteration = self.load_checkpoint_idx(checkpoint_record)
+            if iteration == -1:
+                return configs
+            checkpoint_path = os.path.join(checkpoint_dir,
+                                           "{}".format(iteration))
+        else:
+            raise ValueError(
+                "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
+            )
+
+        rank = dist.get_rank()
+
+        params_path = checkpoint_path + ".pdparams"
+        model_dict = paddle.load(params_path)
+        model.set_state_dict(model_dict)
+        logger.info("Rank {}: loaded model from {}".format(rank, params_path))
+
+        optimizer_path = checkpoint_path + ".pdopt"
+        if optimizer and os.path.isfile(optimizer_path):
+            optimizer_dict = paddle.load(optimizer_path)
+            optimizer.set_state_dict(optimizer_dict)
+            logger.info("Rank {}: loaded optimizer state from {}".format(
+                rank, optimizer_path))
+
+        info_path = re.sub('.pdparams$', '.json', params_path)
+        if os.path.exists(info_path):
+            with open(info_path, 'r') as fin:
+                configs = json.load(fin)
+
+        return configs
+
     @mp_tools.rank_zero_only
     def save_parameters(self,
                         checkpoint_dir: str,
                         tag_or_iteration: Union[int, str],
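
Note (not part of the diff): the refactor leaves two loaders with the same signature. load_last_parameters() follows the "checkpoint_latest" record and is what the trainer uses to resume, while the new load_best_parameters() follows the "checkpoint_best" record, e.g. for evaluating a trained model. A usage sketch with a placeholder model and directory:

import paddle
from deepspeech.utils.checkpoint import Checkpoint

model = paddle.nn.Linear(10, 10)  # placeholder model
checkpoint = Checkpoint(kbest_n=5, latest_n=1)

# Resume from the most recent checkpoint (reads "checkpoint_latest").
resume_infos = checkpoint.load_last_parameters(
    model, checkpoint_dir="exp/checkpoints")

# Load from the best-checkpoint record (reads "checkpoint_best").
best_infos = checkpoint.load_best_parameters(
    model, checkpoint_dir="exp/checkpoints")

# Both return a dict of saved meta info (epoch/step, lr, ...),
# or {} when no record file exists yet.
print(resume_infos, best_infos)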

examples/tiny/s0/conf/deepspeech2.yaml

@@ -43,7 +43,7 @@ model:
   share_rnn_weights: True
 
 training:
-  n_epoch: 6
+  n_epoch: 10
   lr: 1e-5
   lr_decay: 1.0
   weight_decay: 1e-06
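
The trainer.py hunks above read kbest_n and latest_n from config.training.checkpoint; the tiny-example config shown here only raises n_epoch. A sketch of the keys involved, written as a yacs-style CfgNode (the exact layout of the yaml section is an assumption, not part of this change):

from yacs.config import CfgNode

config = CfgNode()
config.training = CfgNode()
config.training.n_epoch = 10              # raised from 6 for examples/tiny
config.training.lr = 1e-5
config.training.checkpoint = CfgNode()
config.training.checkpoint.kbest_n = 5    # -> Checkpoint(kbest_n=...)
config.training.checkpoint.latest_n = 1   # -> Checkpoint(latest_n=...)

# The trainer accesses these when it builds the checkpointer:
print(config.training.checkpoint.kbest_n,
      config.training.checkpoint.latest_n)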