Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
717fe1e4
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
717fe1e4
编写于
6月 29, 2021
作者:
H
Hui Zhang
提交者:
GitHub
6月 29, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #680 from PaddlePaddle/checkpoint
checkpoint refactor to save disk space
上级
1c9c122b
c0f7aac8
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
307 addition
and
129 deletion
+307
-129
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+10
-6
deepspeech/utils/checkpoint.py
deepspeech/utils/checkpoint.py
+256
-122
examples/aishell/s0/conf/deepspeech2.yaml
examples/aishell/s0/conf/deepspeech2.yaml
+3
-0
examples/aishell/s1/conf/chunk_conformer.yaml
examples/aishell/s1/conf/chunk_conformer.yaml
+3
-0
examples/aishell/s1/conf/conformer.yaml
examples/aishell/s1/conf/conformer.yaml
+3
-0
examples/librispeech/s0/conf/deepspeech2.yaml
examples/librispeech/s0/conf/deepspeech2.yaml
+3
-0
examples/librispeech/s1/conf/chunk_confermer.yaml
examples/librispeech/s1/conf/chunk_confermer.yaml
+3
-0
examples/librispeech/s1/conf/chunk_transformer.yaml
examples/librispeech/s1/conf/chunk_transformer.yaml
+3
-0
examples/librispeech/s1/conf/conformer.yaml
examples/librispeech/s1/conf/conformer.yaml
+3
-0
examples/librispeech/s1/conf/transformer.yaml
examples/librispeech/s1/conf/transformer.yaml
+3
-0
examples/tiny/s0/conf/deepspeech2.yaml
examples/tiny/s0/conf/deepspeech2.yaml
+5
-1
examples/tiny/s1/conf/chunk_confermer.yaml
examples/tiny/s1/conf/chunk_confermer.yaml
+3
-0
examples/tiny/s1/conf/chunk_transformer.yaml
examples/tiny/s1/conf/chunk_transformer.yaml
+3
-0
examples/tiny/s1/conf/conformer.yaml
examples/tiny/s1/conf/conformer.yaml
+3
-0
examples/tiny/s1/conf/transformer.yaml
examples/tiny/s1/conf/transformer.yaml
+3
-0
未找到文件。
deepspeech/training/trainer.py
浏览文件 @
717fe1e4
...
@@ -18,8 +18,8 @@ import paddle
...
@@ -18,8 +18,8 @@ import paddle
from
paddle
import
distributed
as
dist
from
paddle
import
distributed
as
dist
from
tensorboardX
import
SummaryWriter
from
tensorboardX
import
SummaryWriter
from
deepspeech.utils
import
checkpoint
from
deepspeech.utils
import
mp_tools
from
deepspeech.utils
import
mp_tools
from
deepspeech.utils.checkpoint
import
Checkpoint
from
deepspeech.utils.log
import
Log
from
deepspeech.utils.log
import
Log
__all__
=
[
"Trainer"
]
__all__
=
[
"Trainer"
]
...
@@ -139,7 +139,7 @@ class Trainer():
...
@@ -139,7 +139,7 @@ class Trainer():
"epoch"
:
self
.
epoch
,
"epoch"
:
self
.
epoch
,
"lr"
:
self
.
optimizer
.
get_lr
()
"lr"
:
self
.
optimizer
.
get_lr
()
})
})
checkpoint
.
save_parameters
(
self
.
checkpoint_dir
,
self
.
iteration
self
.
checkpoint
.
add_checkpoint
(
self
.
checkpoint_dir
,
self
.
iteration
if
tag
is
None
else
tag
,
self
.
model
,
if
tag
is
None
else
tag
,
self
.
model
,
self
.
optimizer
,
infos
)
self
.
optimizer
,
infos
)
...
@@ -151,7 +151,7 @@ class Trainer():
...
@@ -151,7 +151,7 @@ class Trainer():
resume training.
resume training.
"""
"""
scratch
=
None
scratch
=
None
infos
=
checkpoint
.
load
_parameters
(
infos
=
self
.
checkpoint
.
load_latest
_parameters
(
self
.
model
,
self
.
model
,
self
.
optimizer
,
self
.
optimizer
,
checkpoint_dir
=
self
.
checkpoint_dir
,
checkpoint_dir
=
self
.
checkpoint_dir
,
...
@@ -180,7 +180,7 @@ class Trainer():
...
@@ -180,7 +180,7 @@ class Trainer():
from_scratch
=
self
.
resume_or_scratch
()
from_scratch
=
self
.
resume_or_scratch
()
if
from_scratch
:
if
from_scratch
:
# save init model, i.e. 0 epoch
# save init model, i.e. 0 epoch
self
.
save
(
tag
=
'init'
)
self
.
save
(
tag
=
'init'
,
infos
=
None
)
self
.
lr_scheduler
.
step
(
self
.
iteration
)
self
.
lr_scheduler
.
step
(
self
.
iteration
)
if
self
.
parallel
:
if
self
.
parallel
:
...
@@ -263,6 +263,10 @@ class Trainer():
...
@@ -263,6 +263,10 @@ class Trainer():
self
.
checkpoint_dir
=
checkpoint_dir
self
.
checkpoint_dir
=
checkpoint_dir
self
.
checkpoint
=
Checkpoint
(
kbest_n
=
self
.
config
.
training
.
checkpoint
.
kbest_n
,
latest_n
=
self
.
config
.
training
.
checkpoint
.
latest_n
)
@
mp_tools
.
rank_zero_only
@
mp_tools
.
rank_zero_only
def
destory
(
self
):
def
destory
(
self
):
"""Close visualizer to avoid hanging after training"""
"""Close visualizer to avoid hanging after training"""
...
...
deepspeech/utils/checkpoint.py
浏览文件 @
717fe1e4
...
@@ -11,9 +11,11 @@
...
@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
glob
import
json
import
json
import
os
import
os
import
re
import
re
from
pathlib
import
Path
from
typing
import
Union
from
typing
import
Union
import
paddle
import
paddle
...
@@ -25,17 +27,143 @@ from deepspeech.utils.log import Log
...
@@ -25,17 +27,143 @@ from deepspeech.utils.log import Log
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
"
load_parameters"
,
"save_parameters
"
]
__all__
=
[
"
Checkpoint
"
]
def
_load_latest_checkpoint
(
checkpoint_dir
:
str
)
->
int
:
class
Checkpoint
(
object
):
def
__init__
(
self
,
kbest_n
:
int
=
5
,
latest_n
:
int
=
1
):
self
.
best_records
:
Mapping
[
Path
,
float
]
=
{}
self
.
latest_records
=
[]
self
.
kbest_n
=
kbest_n
self
.
latest_n
=
latest_n
self
.
_save_all
=
(
kbest_n
==
-
1
)
def
add_checkpoint
(
self
,
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
,
metric_type
=
"val_loss"
):
if
(
metric_type
not
in
infos
.
keys
()):
self
.
_save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
return
#save best
if
self
.
_should_save_best
(
infos
[
metric_type
]):
self
.
_save_best_checkpoint_and_update
(
infos
[
metric_type
],
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
#save latest
self
.
_save_latest_checkpoint_and_update
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
if
isinstance
(
tag_or_iteration
,
int
):
self
.
_save_checkpoint_record
(
checkpoint_dir
,
tag_or_iteration
)
def
load_latest_parameters
(
self
,
model
,
optimizer
=
None
,
checkpoint_dir
=
None
,
checkpoint_path
=
None
):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
return
self
.
_load_parameters
(
model
,
optimizer
,
checkpoint_dir
,
checkpoint_path
,
"checkpoint_latest"
)
def
load_best_parameters
(
self
,
model
,
optimizer
=
None
,
checkpoint_dir
=
None
,
checkpoint_path
=
None
):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
return
self
.
_load_parameters
(
model
,
optimizer
,
checkpoint_dir
,
checkpoint_path
,
"checkpoint_best"
)
def
_should_save_best
(
self
,
metric
:
float
)
->
bool
:
if
not
self
.
_best_full
():
return
True
# already full
worst_record_path
=
max
(
self
.
best_records
,
key
=
self
.
best_records
.
get
)
# worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
worst_metric
=
self
.
best_records
[
worst_record_path
]
return
metric
<
worst_metric
def
_best_full
(
self
):
return
(
not
self
.
_save_all
)
and
len
(
self
.
best_records
)
==
self
.
kbest_n
def
_latest_full
(
self
):
return
len
(
self
.
latest_records
)
==
self
.
latest_n
def
_save_best_checkpoint_and_update
(
self
,
metric
,
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
):
# remove the worst
if
self
.
_best_full
():
worst_record_path
=
max
(
self
.
best_records
,
key
=
self
.
best_records
.
get
)
self
.
best_records
.
pop
(
worst_record_path
)
if
(
worst_record_path
not
in
self
.
latest_records
):
logger
.
info
(
"remove the worst checkpoint: {}"
.
format
(
worst_record_path
))
self
.
_del_checkpoint
(
checkpoint_dir
,
worst_record_path
)
# add the new one
self
.
_save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
self
.
best_records
[
tag_or_iteration
]
=
metric
def
_save_latest_checkpoint_and_update
(
self
,
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
):
# remove the old
if
self
.
_latest_full
():
to_del_fn
=
self
.
latest_records
.
pop
(
0
)
if
(
to_del_fn
not
in
self
.
best_records
.
keys
()):
logger
.
info
(
"remove the latest checkpoint: {}"
.
format
(
to_del_fn
))
self
.
_del_checkpoint
(
checkpoint_dir
,
to_del_fn
)
self
.
latest_records
.
append
(
tag_or_iteration
)
self
.
_save_parameters
(
checkpoint_dir
,
tag_or_iteration
,
model
,
optimizer
,
infos
)
def
_del_checkpoint
(
self
,
checkpoint_dir
,
tag_or_iteration
):
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"{}"
.
format
(
tag_or_iteration
))
for
filename
in
glob
.
glob
(
checkpoint_path
+
".*"
):
os
.
remove
(
filename
)
logger
.
info
(
"delete file: {}"
.
format
(
filename
))
def
_load_checkpoint_idx
(
self
,
checkpoint_record
:
str
)
->
int
:
"""Get the iteration number corresponding to the latest saved checkpoint.
"""Get the iteration number corresponding to the latest saved checkpoint.
Args:
Args:
checkpoint_dir (str): the directory where checkpoint is saved
.
checkpoint_path (str): the saved path of checkpoint
.
Returns:
Returns:
int: the latest iteration number. -1 for no checkpoint to load.
int: the latest iteration number. -1 for no checkpoint to load.
"""
"""
checkpoint_record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
if
not
os
.
path
.
isfile
(
checkpoint_record
):
if
not
os
.
path
.
isfile
(
checkpoint_record
):
return
-
1
return
-
1
...
@@ -45,8 +173,7 @@ def _load_latest_checkpoint(checkpoint_dir: str) -> int:
...
@@ -45,8 +173,7 @@ def _load_latest_checkpoint(checkpoint_dir: str) -> int:
iteration
=
int
(
latest_checkpoint
.
split
(
":"
)[
-
1
])
iteration
=
int
(
latest_checkpoint
.
split
(
":"
)[
-
1
])
return
iteration
return
iteration
def
_save_checkpoint_record
(
self
,
checkpoint_dir
:
str
,
iteration
:
int
):
def
_save_record
(
checkpoint_dir
:
str
,
iteration
:
int
):
"""Save the iteration number of the latest model to be checkpoint record.
"""Save the iteration number of the latest model to be checkpoint record.
Args:
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
checkpoint_dir (str): the directory where checkpoint is saved.
...
@@ -54,17 +181,24 @@ def _save_record(checkpoint_dir: str, iteration: int):
...
@@ -54,17 +181,24 @@ def _save_record(checkpoint_dir: str, iteration: int):
Returns:
Returns:
None
None
"""
"""
checkpoint_record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
checkpoint_record_latest
=
os
.
path
.
join
(
checkpoint_dir
,
# Update the latest checkpoint index.
"checkpoint_latest"
)
with
open
(
checkpoint_record
,
"a+"
)
as
handle
:
checkpoint_record_best
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint_best"
)
handle
.
write
(
"model_checkpoint_path:{}
\n
"
.
format
(
iteration
))
with
open
(
checkpoint_record_best
,
"w"
)
as
handle
:
for
i
in
self
.
best_records
.
keys
():
handle
.
write
(
"model_checkpoint_path:{}
\n
"
.
format
(
i
))
with
open
(
checkpoint_record_latest
,
"w"
)
as
handle
:
for
i
in
self
.
latest_records
:
handle
.
write
(
"model_checkpoint_path:{}
\n
"
.
format
(
i
))
def
load_parameters
(
model
,
def
_load_parameters
(
self
,
model
,
optimizer
=
None
,
optimizer
=
None
,
checkpoint_dir
=
None
,
checkpoint_dir
=
None
,
checkpoint_path
=
None
):
checkpoint_path
=
None
,
"""Load a specific model checkpoint from disk.
checkpoint_file
=
None
):
"""Load a last model checkpoint from disk.
Args:
Args:
model (Layer): model to load parameters.
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
optimizer (Optimizer, optional): optimizer to load states if needed.
...
@@ -73,6 +207,7 @@ def load_parameters(model,
...
@@ -73,6 +207,7 @@ def load_parameters(model,
checkpoint_path (str, optional): if specified, load the checkpoint
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
be ignored. Defaults to None.
checkpoint_file "checkpoint_latest" or "checkpoint_best"
Returns:
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
configs (dict): epoch or step, lr and other meta info should be saved.
"""
"""
...
@@ -80,14 +215,16 @@ def load_parameters(model,
...
@@ -80,14 +215,16 @@ def load_parameters(model,
if
checkpoint_path
is
not
None
:
if
checkpoint_path
is
not
None
:
tag
=
os
.
path
.
basename
(
checkpoint_path
).
split
(
":"
)[
-
1
]
tag
=
os
.
path
.
basename
(
checkpoint_path
).
split
(
":"
)[
-
1
]
elif
checkpoint_dir
is
not
None
:
elif
checkpoint_dir
is
not
None
and
checkpoint_file
is
not
None
:
iteration
=
_load_latest_checkpoint
(
checkpoint_dir
)
checkpoint_record
=
os
.
path
.
join
(
checkpoint_dir
,
checkpoint_file
)
iteration
=
self
.
_load_checkpoint_idx
(
checkpoint_record
)
if
iteration
==
-
1
:
if
iteration
==
-
1
:
return
configs
return
configs
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"{}"
.
format
(
iteration
))
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"{}"
.
format
(
iteration
))
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"At least one of 'checkpoint_dir
' and 'checkpoint_path' should be specified!"
"At least one of 'checkpoint_dir' and 'checkpoint_file
' and 'checkpoint_path' should be specified!"
)
)
rank
=
dist
.
get_rank
()
rank
=
dist
.
get_rank
()
...
@@ -110,9 +247,9 @@ def load_parameters(model,
...
@@ -110,9 +247,9 @@ def load_parameters(model,
configs
=
json
.
load
(
fin
)
configs
=
json
.
load
(
fin
)
return
configs
return
configs
@
mp_tools
.
rank_zero_only
@
mp_tools
.
rank_zero_only
def
_save_parameters
(
self
,
def
save_parameters
(
checkpoint_dir
:
str
,
checkpoint_dir
:
str
,
tag_or_iteration
:
Union
[
int
,
str
],
tag_or_iteration
:
Union
[
int
,
str
],
model
:
paddle
.
nn
.
Layer
,
model
:
paddle
.
nn
.
Layer
,
optimizer
:
Optimizer
=
None
,
optimizer
:
Optimizer
=
None
,
...
@@ -147,6 +284,3 @@ def save_parameters(checkpoint_dir: str,
...
@@ -147,6 +284,3 @@ def save_parameters(checkpoint_dir: str,
with
open
(
info_path
,
'w'
)
as
fout
:
with
open
(
info_path
,
'w'
)
as
fout
:
data
=
json
.
dumps
(
infos
)
data
=
json
.
dumps
(
infos
)
fout
.
write
(
data
)
fout
.
write
(
data
)
if
isinstance
(
tag_or_iteration
,
int
):
_save_record
(
checkpoint_dir
,
tag_or_iteration
)
examples/aishell/s0/conf/deepspeech2.yaml
浏览文件 @
717fe1e4
...
@@ -48,6 +48,9 @@ training:
...
@@ -48,6 +48,9 @@ training:
weight_decay
:
1e-06
weight_decay
:
1e-06
global_grad_clip
:
3.0
global_grad_clip
:
3.0
log_interval
:
100
log_interval
:
100
checkpoint
:
kbest_n
:
50
latest_n
:
5
decoding
:
decoding
:
batch_size
:
128
batch_size
:
128
...
...
examples/aishell/s1/conf/chunk_conformer.yaml
浏览文件 @
717fe1e4
...
@@ -93,6 +93,9 @@ training:
...
@@ -93,6 +93,9 @@ training:
warmup_steps
:
25000
warmup_steps
:
25000
lr_decay
:
1.0
lr_decay
:
1.0
log_interval
:
100
log_interval
:
100
checkpoint
:
kbest_n
:
50
latest_n
:
5
decoding
:
decoding
:
...
...
examples/aishell/s1/conf/conformer.yaml
浏览文件 @
717fe1e4
...
@@ -88,6 +88,9 @@ training:
...
@@ -88,6 +88,9 @@ training:
warmup_steps
:
25000
warmup_steps
:
25000
lr_decay
:
1.0
lr_decay
:
1.0
log_interval
:
100
log_interval
:
100
checkpoint
:
kbest_n
:
50
latest_n
:
5
decoding
:
decoding
:
...
...
examples/librispeech/s0/conf/deepspeech2.yaml
浏览文件 @
717fe1e4
...
@@ -48,6 +48,9 @@ training:
...
@@ -48,6 +48,9 @@ training:
weight_decay
:
1e-06
weight_decay
:
1e-06
global_grad_clip
:
5.0
global_grad_clip
:
5.0
log_interval
:
100
log_interval
:
100
checkpoint
:
kbest_n
:
50
latest_n
:
5
decoding
:
decoding
:
batch_size
:
128
batch_size
:
128
...
...
examples/librispeech/s1/conf/chunk_confermer.yaml
浏览文件 @
717fe1e4
...
@@ -93,6 +93,9 @@ training:
...
@@ -93,6 +93,9 @@ training:
warmup_steps
:
25000
warmup_steps
:
25000
lr_decay
:
1.0
lr_decay
:
1.0
log_interval
:
100
log_interval
:
100
checkpoint
:
kbest_n
:
50
latest_n
:
5
decoding
:
decoding
:
...
...
examples/librispeech/s1/conf/chunk_transformer.yaml
浏览文件 @
717fe1e4
...
@@ -86,6 +86,9 @@ training:
...
@@ -86,6 +86,9 @@ training:
warmup_steps
:
25000
warmup_steps
:
25000
lr_decay
:
1.0
lr_decay
:
1.0
log_interval
:
100
log_interval
:
100
checkpoint
:
kbest_n
:
50
latest_n
:
5
decoding
:
decoding
:
...
...
examples/librispeech/s1/conf/conformer.yaml
浏览文件 @
717fe1e4
...
@@ -89,6 +89,9 @@ training:
...
@@ -89,6 +89,9 @@ training:
warmup_steps
:
25000
warmup_steps
:
25000
lr_decay
:
1.0
lr_decay
:
1.0
log_interval
:
100
log_interval
:
100
checkpoint
:
kbest_n
:
50
latest_n
:
5
decoding
:
decoding
:
...
...
examples/librispeech/s1/conf/transformer.yaml
浏览文件 @
717fe1e4
...
@@ -84,6 +84,9 @@ training:
...
@@ -84,6 +84,9 @@ training:
warmup_steps
:
25000
warmup_steps
:
25000
lr_decay
:
1.0
lr_decay
:
1.0
log_interval
:
100
log_interval
:
100
checkpoint
:
kbest_n
:
50
latest_n
:
5
decoding
:
decoding
:
...
...
examples/tiny/s0/conf/deepspeech2.yaml
浏览文件 @
717fe1e4
...
@@ -43,12 +43,16 @@ model:
...
@@ -43,12 +43,16 @@ model:
share_rnn_weights
:
True
share_rnn_weights
:
True
training
:
training
:
n_epoch
:
24
n_epoch
:
10
lr
:
1e-5
lr
:
1e-5
lr_decay
:
1.0
lr_decay
:
1.0
weight_decay
:
1e-06
weight_decay
:
1e-06
global_grad_clip
:
5.0
global_grad_clip
:
5.0
log_interval
:
1
log_interval
:
1
checkpoint
:
kbest_n
:
3
latest_n
:
2
decoding
:
decoding
:
batch_size
:
128
batch_size
:
128
...
...
examples/tiny/s1/conf/chunk_confermer.yaml
浏览文件 @
717fe1e4
...
@@ -91,6 +91,9 @@ training:
...
@@ -91,6 +91,9 @@ training:
warmup_steps
:
25000
warmup_steps
:
25000
lr_decay
:
1.0
lr_decay
:
1.0
log_interval
:
1
log_interval
:
1
checkpoint
:
kbest_n
:
10
latest_n
:
1
decoding
:
decoding
:
...
...
examples/tiny/s1/conf/chunk_transformer.yaml
浏览文件 @
717fe1e4
...
@@ -84,6 +84,9 @@ training:
...
@@ -84,6 +84,9 @@ training:
warmup_steps
:
25000
warmup_steps
:
25000
lr_decay
:
1.0
lr_decay
:
1.0
log_interval
:
1
log_interval
:
1
checkpoint
:
kbest_n
:
10
latest_n
:
1
decoding
:
decoding
:
...
...
examples/tiny/s1/conf/conformer.yaml
浏览文件 @
717fe1e4
...
@@ -87,6 +87,9 @@ training:
...
@@ -87,6 +87,9 @@ training:
warmup_steps
:
25000
warmup_steps
:
25000
lr_decay
:
1.0
lr_decay
:
1.0
log_interval
:
1
log_interval
:
1
checkpoint
:
kbest_n
:
10
latest_n
:
1
decoding
:
decoding
:
...
...
examples/tiny/s1/conf/transformer.yaml
浏览文件 @
717fe1e4
...
@@ -84,6 +84,9 @@ training:
...
@@ -84,6 +84,9 @@ training:
warmup_steps
:
25000
warmup_steps
:
25000
lr_decay
:
1.0
lr_decay
:
1.0
log_interval
:
1
log_interval
:
1
checkpoint
:
kbest_n
:
10
latest_n
:
1
decoding
:
decoding
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录