5ea181b7
编写于
4月 14, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix train logitc
上级
b5bbfc5e
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
81 addition
and
72 deletion
+81
-72
deepspeech/exps/deepspeech2/model.py       +5  −5
deepspeech/exps/u2/bin/train.py            +6  −1
deepspeech/exps/u2/model.py               +38 −33
deepspeech/training/trainer.py            +22 −24
deepspeech/utils/checkpoint.py             +9  −7
examples/aishell/s1/conf/conformer.yaml    +0  −1
examples/tiny/s1/conf/conformer.yaml       +1  −1
deepspeech/exps/deepspeech2/model.py

@@ -45,9 +45,10 @@ class DeepSpeech2Trainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
 
-    def train_batch(self, batch_data):
-        start = time.time()
+    def train_batch(self, batch_data, msg):
+        self.model.train()
+        start = time.time()
         loss = self.model(*batch_data)
         loss.backward()
         layer_tools.print_grads(self.model, print_func=None)
@@ -59,10 +60,8 @@ class DeepSpeech2Trainer(Trainer):
         losses_np = {
             'train_loss': float(loss),
         }
-        msg = "Train: Rank: {}, ".format(dist.get_rank())
-        msg += "epoch: {}, ".format(self.epoch)
-        msg += "step: {}, ".format(self.iteration)
         msg += "time: {:>.3f}s, ".format(iteration_time)
         msg += "batch size: {}, ".format(self.config.data.batch_size)
         msg += ', '.join('{}: {:>.6f}'.format(k, v)
                          for k, v in losses_np.items())
         self.logger.info(msg)
@@ -71,6 +70,7 @@ class DeepSpeech2Trainer(Trainer):
             for k, v in losses_np.items():
                 self.visualizer.add_scalar("train/{}".format(k), v,
                                            self.iteration)
+        self.iteration += 1
 
     @mp_tools.rank_zero_only
     @paddle.no_grad()
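The notable additions in this file are the msg parameter (the caller now supplies the Rank/epoch/step prefix) and the explicit self.model.train() at the top of train_batch(). A minimal sketch of why the latter matters, using a toy Paddle model rather than anything from the repo:

    # Paddle layers carry a training flag that controls dropout and batch
    # normalization; validation code calls model.eval(), so each training
    # step has to switch the flag back before the forward pass.
    import paddle

    model = paddle.nn.Sequential(
        paddle.nn.Linear(4, 4),
        paddle.nn.Dropout(0.5))

    model.eval()       # e.g. left over from a validation pass
    model.train()      # what the new train_batch() does first
    x = paddle.randn([2, 4])
    y = model(x)       # dropout is active again for this forward pass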
deepspeech/exps/u2/bin/train.py

@@ -13,6 +13,8 @@
 # limitations under the License.
 """Trainer for U2 model."""
+import os
+import cProfile
 
 from paddle import distributed as dist
 
 from deepspeech.utils.utility import print_arguments
@@ -52,4 +54,7 @@ if __name__ == "__main__":
     with open(args.dump_config, 'w') as f:
         print(config, file=f)
 
-    main(config, args)
+    # Setting for profiling
+    pr = cProfile.Profile()
+    pr.runcall(main, config, args)
+    pr.dump_stats(os.path.join('.', 'train.profile'))
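Here pr.runcall(main, config, args) executes main under the profiler and dump_stats writes the raw profile to ./train.profile. A quick way to inspect that file afterwards, using only the standard library (not part of this repo):

    import pstats

    stats = pstats.Stats('train.profile')
    # Show the 20 entries with the largest cumulative time.
    stats.sort_stats('cumulative').print_stats(20)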
deepspeech/exps/u2/model.py

@@ -80,54 +80,60 @@ class U2Trainer(Trainer):
         self.model.train()
 
         start = time.time()
         loss, attention_loss, ctc_loss = self.model(*batch_data)
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
         loss.backward()
         layer_tools.print_grads(self.model, print_func=None)
 
-        if self.iteration % train_conf.accum_grad == 0:
-            if dist.get_rank() == 0 and self.visualizer:
-                for k, v in losses_np.items():
-                    self.visualizer.add_scalar("train/{}".format(k), v,
-                                               self.iteration)
+        losses_np = {
+            'train_loss': float(loss) * train_conf.accum_grad,
+            'train_att_loss': float(attention_loss),
+            'train_ctc_loss': float(ctc_loss),
+        }
+
+        if (self.iteration + 1) % train_conf.accum_grad == 0:
             self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
-            self.iteration += 1
 
         iteration_time = time.time() - start
 
-        losses_np = {
-            'train_loss': float(loss),
-            'train_att_loss': float(attention_loss),
-            'train_ctc_loss': float(ctc_loss),
-        }
-        msg += "time: {:>.3f}s, ".format(iteration_time)
-        msg += "batch size: {}, ".format(self.config.data.batch_size)
-        msg += "accum: {}, ".format(train_conf.accum_grad)
-        msg += ', '.join('{}: {:>.6f}'.format(k, v)
-                         for k, v in losses_np.items())
-        if self.iteration % train_conf.log_interval == 0:
+        if (self.iteration + 1) % train_conf.log_interval == 0:
+            msg += "time: {:>.3f}s, ".format(iteration_time)
+            msg += "batch size: {}, ".format(self.config.data.batch_size)
+            msg += "accum: {}, ".format(train_conf.accum_grad)
+            msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                             for k, v in losses_np.items())
             self.logger.info(msg)
 
             # display
             if dist.get_rank() == 0 and self.visualizer:
                 for k, v in losses_np.items():
                     self.visualizer.add_scalar("train/{}".format(k), v,
                                                self.iteration)
 
     def train(self):
-        """The training process.
-
-        It includes forward/backward/update and periodical validation and
-        saving.
-        """
+        """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
         # the code to satisfy the script export requirements
         # script_model = paddle.jit.to_static(self.model)
         # script_model_path = str(self.checkpoint_dir / 'init')
         # paddle.jit.save(script_model, script_model_path)
 
+        from_scratch = self.resume_or_scratch()
+
+        if from_scratch:
+            # save init model, i.e. 0 epoch
+            self.save(tag='init')
+
         self.lr_scheduler.step(self.iteration)
         if self.parallel:
             self.train_loader.batch_sampler.set_epoch(self.epoch)
 
         self.logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
-        while self.epoch <= self.config.training.n_epoch:
+        while self.epoch < self.config.training.n_epoch:
             try:
                 data_start_time = time.time()
                 for batch in self.train_loader:
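The reworked train_batch keeps the gradient-accumulation bookkeeping in one place: the loss is divided by accum_grad before backward so the accumulated gradients match a single large batch, the optimizer steps only on every accum_grad-th batch (counting from 1, hence the iteration + 1), and the logged train_loss is multiplied back so it stays comparable across accum_grad settings. A self-contained sketch of the same pattern with a toy model; accum_grad and the data are made up for illustration:

    import paddle

    model = paddle.nn.Linear(8, 1)
    opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())
    accum_grad = 4

    for step in range(16):
        x = paddle.randn([2, 8])
        loss = model(x).mean()
        # Scale so the accumulated gradient matches one big batch ...
        (loss / accum_grad).backward()
        # ... but report the unscaled value, as the diff does with
        # float(loss) * train_conf.accum_grad.
        print("step {} loss {:.4f}".format(step + 1, float(loss)))
        # Update only every accum_grad-th batch, counting from 1.
        if (step + 1) % accum_grad == 0:
            opt.step()
            opt.clear_grad()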
@@ -135,19 +141,18 @@ class U2Trainer(Trainer):
                     msg = "Train: Rank: {}, ".format(dist.get_rank())
                     msg += "epoch: {}, ".format(self.epoch)
                     msg += "step: {}, ".format(self.iteration)
-                    msg += "lr: {}, ".foramt(self.lr_scheduler())
+                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                     msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
+                    self.iteration += 1
                     self.train_batch(batch, msg)
                     data_start_time = time.time()
             except Exception as e:
                 self.logger.error(e)
                 raise e
 
             self.valid()
             self.save()
             self.new_epoch()
 
     @mp_tools.rank_zero_only
     @paddle.no_grad()
     def valid(self):
@@ -263,12 +268,12 @@ class U2Trainer(Trainer):
             lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
                 learning_rate=optim_conf.lr,
                 gamma=scheduler_conf.lr_decay,
-                verbose=True)
+                verbose=False)
         elif scheduler_type == 'warmuplr':
             lr_scheduler = WarmupLR(
                 learning_rate=optim_conf.lr,
                 warmup_steps=scheduler_conf.warmup_steps,
-                verbose=True)
+                verbose=False)
         else:
             raise ValueError(f"Not support scheduler: {scheduler_type}")
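The only change in this hunk is turning off verbose on the learning-rate schedulers; with verbose=True, Paddle's LR schedulers print the new learning rate on every step(), which is noisy now that step() runs once per accumulated batch. A small illustration with ExponentialDecay (calling sched() mirrors the self.lr_scheduler() call used elsewhere in this diff):

    import paddle

    sched = paddle.optimizer.lr.ExponentialDecay(
        learning_rate=0.002, gamma=0.9, verbose=False)
    for _ in range(3):
        sched.step()   # decays silently; verbose=True would print each update
    print(sched())     # current learning rate after three decays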
deepspeech/training/trainer.py

@@ -127,7 +127,7 @@ class Trainer():
             dist.init_parallel_env()
 
     @mp_tools.rank_zero_only
-    def save(self, infos=None):
+    def save(self, tag=None, infos=None):
         """Save checkpoint (model parameters and optimizer states).
         """
         if infos is None:
@@ -136,8 +136,9 @@ class Trainer():
                 "epoch": self.epoch,
                 "lr": self.optimizer.get_lr(),
             }
-        checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
-                                   self.model, self.optimizer, infos)
+        checkpoint.save_parameters(self.checkpoint_dir, self.iteration
+                                   if tag is None else tag, self.model,
+                                   self.optimizer, infos)
 
     def resume_or_scratch(self):
         """Resume from latest checkpoint at checkpoints in the output
@@ -146,6 +147,7 @@ class Trainer():
         If ``args.checkpoint_path`` is not None, load the checkpoint, else
         resume training.
         """
+        scratch = None
         infos = checkpoint.load_parameters(
             self.model,
             self.optimizer,
@@ -155,44 +157,41 @@ class Trainer():
             # restore from ckpt
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
-            self.lr_scheduler.step(self.iteration)
-            if self.parallel:
-                self.train_loader.batch_sampler.set_epoch(self.epoch)
-            return False
+            scratch = False
         else:
             # from scratch, epoch and iteration init with zero
-            # save init model, i.e. 0 epoch
-            self.save()
-            # self.epoch start from 1.
-            self.new_epoch()
-            return True
+            scratch = True
+
+        return scratch
 
     def new_epoch(self):
-        """Reset the train loader and increment ``epoch``.
-        """
-        self.epoch += 1
+        """Reset the train loader seed and increment `epoch`.
+        """
         if self.parallel:
+            # batch sampler epoch start from 0
             self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.epoch += 1
 
     def train(self):
-        """The training process.
-        """
+        """The training process control by epoch."""
+        from_scratch = self.resume_or_scratch()
+
+        if from_scratch:
+            # save init model, i.e. 0 epoch
+            self.save(tag='init')
+
         self.lr_scheduler.step(self.iteration)
         if self.parallel:
             self.train_loader.batch_sampler.set_epoch(self.epoch)
 
         self.logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
-        while self.epoch <= self.config.training.n_epoch:
+        while self.epoch < self.config.training.n_epoch:
             try:
                 data_start_time = time.time()
                 for batch in self.train_loader:
                     dataload_time = time.time() - data_start_time
+                    # iteration start from 1.
+                    self.iteration += 1
                     msg = "Train: Rank: {}, ".format(dist.get_rank())
                     msg += "epoch: {}, ".format(self.epoch)
                     msg += "step: {}, ".format(self.iteration)
                     msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                     msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
                     self.train_batch(batch, msg)
                     data_start_time = time.time()
@@ -202,7 +201,6 @@ class Trainer():
             self.valid()
             self.save()
+            # lr control by epoch
             self.lr_scheduler.step()
             self.new_epoch()
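The overall flow after this change: resume_or_scratch() only reports whether training starts fresh, train() saves an 'init' checkpoint in that case, epoch starts at 0 and is incremented by new_epoch() after validation and saving, and the loop bound becomes epoch < n_epoch so exactly n_epoch epochs run. A stripped-down sketch of that control flow (plain Python, no repo code):

    n_epoch = 3
    epoch = 0          # epoch and iteration init with zero
    completed = []

    while epoch < n_epoch:
        completed.append(epoch)   # train, valid, save for this epoch
        epoch += 1                # what new_epoch() does at the end

    print(completed)              # [0, 1, 2] -> exactly n_epoch epochs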
deepspeech/utils/checkpoint.py

@@ -16,6 +16,7 @@ import os
 import logging
 import re
 import json
+from typing import Union
 
 import paddle
 from paddle import distributed as dist
@@ -79,7 +80,7 @@ def load_parameters(model,
     configs = {}
 
     if checkpoint_path is not None:
-        iteration = int(os.path.basename(checkpoint_path).split(":")[-1])
+        tag = os.path.basename(checkpoint_path).split(":")[-1]
     elif checkpoint_dir is not None:
         iteration = _load_latest_checkpoint(checkpoint_dir)
         if iteration == -1:
@@ -113,14 +114,14 @@ def load_parameters(model,
 
 @mp_tools.rank_zero_only
 def save_parameters(checkpoint_dir: str,
-                    iteration: int,
+                    tag_or_iteration: Union[int, str],
                     model: paddle.nn.Layer,
                     optimizer: Optimizer=None,
                     infos: dict=None):
     """Checkpoint the latest trained model parameters.
     Args:
         checkpoint_dir (str): the directory where checkpoint is saved.
-        iteration (int): the latest iteration(step or epoch) number.
+        tag_or_iteration (int or str): the latest iteration(step or epoch) number.
         model (Layer): model to be checkpointed.
         optimizer (Optimizer, optional): optimizer to be checkpointed.
             Defaults to None.
@@ -128,7 +129,8 @@ def save_parameters(checkpoint_dir: str,
     Returns:
         None
     """
-    checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
+    checkpoint_path = os.path.join(checkpoint_dir,
+                                   "{}".format(tag_or_iteration))
 
     model_dict = model.state_dict()
     params_path = checkpoint_path + ".pdparams"
@@ -142,10 +144,10 @@ def save_parameters(checkpoint_dir: str,
         logger.info("Saved optimzier state to {}".format(optimizer_path))
 
     info_path = re.sub('.pdparams$', '.json', params_path)
-    if infos is None:
-        infos = {}
+    infos = {} if infos is None else infos
     with open(info_path, 'w') as fout:
         data = json.dumps(infos)
         fout.write(data)
 
-    _save_checkpoint(checkpoint_dir, iteration)
+    if isinstance(tag_or_iteration, int):
+        _save_checkpoint(checkpoint_dir, tag_or_iteration)
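save_parameters now accepts either an integer step or a string tag for the checkpoint name, and only integer steps are recorded as the latest checkpoint, so tagged snapshots such as 'init' never become the resume target. A rough sketch of that convention; the record-file name and helper names below are hypothetical, not taken from _save_checkpoint:

    import os
    from typing import Union

    def checkpoint_basename(checkpoint_dir: str,
                            tag_or_iteration: Union[int, str]) -> str:
        # Both 'init' and 5000 simply become the file stem.
        return os.path.join(checkpoint_dir, "{}".format(tag_or_iteration))

    def record_latest(checkpoint_dir: str,
                      tag_or_iteration: Union[int, str]) -> None:
        # Mirrors the new guard around _save_checkpoint(): only integer
        # iterations update the resume pointer.
        if isinstance(tag_or_iteration, int):
            with open(os.path.join(checkpoint_dir, "latest"), "w") as f:
                f.write("{}".format(tag_or_iteration))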
examples/aishell/s1/conf/conformer.yaml

@@ -6,7 +6,6 @@ data:
   vocab_filepath: data/vocab.txt
   unit_type: 'char'
   spm_model_prefix: ''
-  mean_std_filepath: ""
   augmentation_config: conf/augmentation.json
   batch_size: 64
   min_input_len: 0.5
examples/tiny/s1/conf/conformer.yaml

@@ -12,7 +12,7 @@ data:
   min_input_len: 0.5
   max_input_len: 20.0
   min_output_len: 0.0
-  max_output_len: 400
+  max_output_len: 400.0
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0
   raw_wav: True  # use raw_wav or kaldi feature