Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
7e136d08
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
11 个月 前同步成功
通知
204
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
7e136d08
编写于
9月 07, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support no_sync for backward; ds support accum grad
上级
41e58631
变更
10
隐藏空白更改
内联
并排
Showing
10 changed files
with
95 additions
and
22 deletions
+95
-22
deepspeech/exps/deepspeech2/model.py
deepspeech/exps/deepspeech2/model.py
+31
-8
deepspeech/exps/u2/model.py
deepspeech/exps/u2/model.py
+19
-4
deepspeech/exps/u2_kaldi/model.py
deepspeech/exps/u2_kaldi/model.py
+18
-3
deepspeech/exps/u2_st/model.py
deepspeech/exps/u2_st/model.py
+18
-4
examples/aishell/s0/conf/deepspeech2.yaml
examples/aishell/s0/conf/deepspeech2.yaml
+1
-0
examples/aishell/s0/conf/deepspeech2_online.yaml
examples/aishell/s0/conf/deepspeech2_online.yaml
+1
-0
examples/librispeech/s0/conf/deepspeech2.yaml
examples/librispeech/s0/conf/deepspeech2.yaml
+2
-1
examples/librispeech/s0/conf/deepspeech2_online.yaml
examples/librispeech/s0/conf/deepspeech2_online.yaml
+2
-1
examples/tiny/s0/conf/deepspeech2.yaml
examples/tiny/s0/conf/deepspeech2.yaml
+1
-0
examples/tiny/s0/conf/deepspeech2_online.yaml
examples/tiny/s0/conf/deepspeech2_online.yaml
+2
-1
未找到文件。
deepspeech/exps/deepspeech2/model.py
浏览文件 @
7e136d08
...
...
@@ -15,6 +15,7 @@
import
os
import
time
from
collections
import
defaultdict
from
contextlib
import
nullcontext
from
pathlib
import
Path
from
typing
import
Optional
...
...
@@ -65,29 +66,51 @@ class DeepSpeech2Trainer(Trainer):
super
().
__init__
(
config
,
args
)
def
train_batch
(
self
,
batch_index
,
batch_data
,
msg
):
train_conf
=
self
.
config
.
training
start
=
time
.
time
()
# forward
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch_data
loss
=
self
.
model
(
audio
,
audio_len
,
text
,
text_len
)
loss
.
backward
()
layer_tools
.
print_grads
(
self
.
model
,
print_func
=
None
)
self
.
optimizer
.
step
()
self
.
optimizer
.
clear_grad
()
iteration_time
=
time
.
time
()
-
start
losses_np
=
{
'train_loss'
:
float
(
loss
),
}
# loss backward
if
(
batch_index
+
1
)
%
train_conf
.
accum_grad
!=
0
:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context
=
self
.
model
.
no_sync
else
:
# Used for single gpu training and DDP gradient synchronization
# processes.
context
=
nullcontext
with
context
():
loss
.
backward
()
layer_tools
.
print_grads
(
self
.
model
,
print_func
=
None
)
# optimizer step
if
(
batch_index
+
1
)
%
train_conf
.
accum_grad
==
0
:
self
.
optimizer
.
step
()
self
.
optimizer
.
clear_grad
()
self
.
iteration
+=
1
iteration_time
=
time
.
time
()
-
start
msg
+=
"train time: {:>.3f}s, "
.
format
(
iteration_time
)
msg
+=
"batch size: {}, "
.
format
(
self
.
config
.
collator
.
batch_size
)
msg
+=
"accum: {}, "
.
format
(
train_conf
.
accum_grad
)
msg
+=
', '
.
join
(
'{}: {:>.6f}'
.
format
(
k
,
v
)
for
k
,
v
in
losses_np
.
items
())
logger
.
info
(
msg
)
if
dist
.
get_rank
()
==
0
and
self
.
visualizer
:
for
k
,
v
in
losses_np
.
items
():
# `step -1` since we update `step` after optimizer.step().
self
.
visualizer
.
add_scalar
(
"train/{}"
.
format
(
k
),
v
,
self
.
iteration
)
self
.
iteration
+=
1
self
.
iteration
-
1
)
@
paddle
.
no_grad
()
def
valid
(
self
):
...
...
deepspeech/exps/u2/model.py
浏览文件 @
7e136d08
...
...
@@ -17,6 +17,7 @@ import os
import
sys
import
time
from
collections
import
defaultdict
from
contextlib
import
nullcontext
from
pathlib
import
Path
from
typing
import
Optional
...
...
@@ -79,21 +80,35 @@ class U2Trainer(Trainer):
def
train_batch
(
self
,
batch_index
,
batch_data
,
msg
):
train_conf
=
self
.
config
.
training
start
=
time
.
time
()
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch_data
# forward
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch_data
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
text_len
)
# loss div by `batch_size * accum_grad`
loss
/=
train_conf
.
accum_grad
loss
.
backward
()
layer_tools
.
print_grads
(
self
.
model
,
print_func
=
None
)
losses_np
=
{
'loss'
:
float
(
loss
)
*
train_conf
.
accum_grad
}
if
attention_loss
:
losses_np
[
'att_loss'
]
=
float
(
attention_loss
)
if
ctc_loss
:
losses_np
[
'ctc_loss'
]
=
float
(
ctc_loss
)
# loss backward
if
(
batch_index
+
1
)
%
train_conf
.
accum_grad
!=
0
:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context
=
self
.
model
.
no_sync
else
:
# Used for single gpu training and DDP gradient synchronization
# processes.
context
=
nullcontext
with
context
():
loss
.
backward
()
layer_tools
.
print_grads
(
self
.
model
,
print_func
=
None
)
# optimizer step
if
(
batch_index
+
1
)
%
train_conf
.
accum_grad
==
0
:
self
.
optimizer
.
step
()
self
.
optimizer
.
clear_grad
()
...
...
deepspeech/exps/u2_kaldi/model.py
浏览文件 @
7e136d08
...
...
@@ -17,6 +17,7 @@ import os
import
sys
import
time
from
collections
import
defaultdict
from
contextlib
import
nullcontext
from
pathlib
import
Path
from
typing
import
Optional
...
...
@@ -83,20 +84,34 @@ class U2Trainer(Trainer):
train_conf
=
self
.
config
.
training
start
=
time
.
time
()
# forward
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch_data
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
text_len
)
# loss div by `batch_size * accum_grad`
loss
/=
train_conf
.
accum_grad
loss
.
backward
()
layer_tools
.
print_grads
(
self
.
model
,
print_func
=
None
)
losses_np
=
{
'loss'
:
float
(
loss
)
*
train_conf
.
accum_grad
}
if
attention_loss
:
losses_np
[
'att_loss'
]
=
float
(
attention_loss
)
if
ctc_loss
:
losses_np
[
'ctc_loss'
]
=
float
(
ctc_loss
)
# loss backward
if
(
batch_index
+
1
)
%
train_conf
.
accum_grad
!=
0
:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context
=
self
.
model
.
no_sync
else
:
# Used for single gpu training and DDP gradient synchronization
# processes.
context
=
nullcontext
with
context
():
loss
.
backward
()
layer_tools
.
print_grads
(
self
.
model
,
print_func
=
None
)
# optimizer step
if
(
batch_index
+
1
)
%
train_conf
.
accum_grad
==
0
:
self
.
optimizer
.
step
()
self
.
optimizer
.
clear_grad
()
...
...
deepspeech/exps/u2_st/model.py
浏览文件 @
7e136d08
...
...
@@ -17,6 +17,7 @@ import os
import
sys
import
time
from
collections
import
defaultdict
from
contextlib
import
nullcontext
from
pathlib
import
Path
from
typing
import
Optional
...
...
@@ -83,6 +84,7 @@ class U2STTrainer(Trainer):
def
train_batch
(
self
,
batch_index
,
batch_data
,
msg
):
train_conf
=
self
.
config
.
training
start
=
time
.
time
()
# forward
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch_data
if
isinstance
(
text
,
list
)
and
isinstance
(
text_len
,
list
):
# joint training with ASR. Two decoding texts [translation, transcription]
...
...
@@ -94,18 +96,30 @@ class U2STTrainer(Trainer):
else
:
loss
,
st_loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
text_len
)
# loss div by `batch_size * accum_grad`
loss
/=
train_conf
.
accum_grad
loss
.
backward
()
layer_tools
.
print_grads
(
self
.
model
,
print_func
=
None
)
losses_np
=
{
'loss'
:
float
(
loss
)
*
train_conf
.
accum_grad
}
losses_np
[
'st_loss'
]
=
float
(
st_loss
)
if
attention_loss
:
losses_np
[
'att_loss'
]
=
float
(
attention_loss
)
if
ctc_loss
:
losses_np
[
'ctc_loss'
]
=
float
(
ctc_loss
)
# loss backward
if
(
batch_index
+
1
)
%
train_conf
.
accum_grad
!=
0
:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context
=
self
.
model
.
no_sync
else
:
# Used for single gpu training and DDP gradient synchronization
# processes.
context
=
nullcontext
with
context
():
loss
.
backward
()
layer_tools
.
print_grads
(
self
.
model
,
print_func
=
None
)
# optimizer step
if
(
batch_index
+
1
)
%
train_conf
.
accum_grad
==
0
:
self
.
optimizer
.
step
()
self
.
optimizer
.
clear_grad
()
...
...
examples/aishell/s0/conf/deepspeech2.yaml
浏览文件 @
7e136d08
...
...
@@ -44,6 +44,7 @@ model:
training
:
n_epoch
:
80
accum_grad
:
1
lr
:
2e-3
lr_decay
:
0.83
weight_decay
:
1e-06
...
...
examples/aishell/s0/conf/deepspeech2_online.yaml
浏览文件 @
7e136d08
...
...
@@ -46,6 +46,7 @@ model:
training
:
n_epoch
:
50
accum_grad
:
1
lr
:
2e-3
lr_decay
:
0.9
# 0.83
weight_decay
:
1e-06
...
...
examples/librispeech/s0/conf/deepspeech2.yaml
浏览文件 @
7e136d08
...
...
@@ -11,7 +11,7 @@ data:
max_output_input_ratio
:
.inf
collator
:
batch_size
:
20
batch_size
:
15
mean_std_filepath
:
data/mean_std.json
unit_type
:
char
vocab_filepath
:
data/vocab.txt
...
...
@@ -44,6 +44,7 @@ model:
training
:
n_epoch
:
50
accum_grad
:
4
lr
:
1e-3
lr_decay
:
0.83
weight_decay
:
1e-06
...
...
examples/librispeech/s0/conf/deepspeech2_online.yaml
浏览文件 @
7e136d08
...
...
@@ -11,7 +11,7 @@ data:
max_output_input_ratio
:
.inf
collator
:
batch_size
:
20
batch_size
:
15
mean_std_filepath
:
data/mean_std.json
unit_type
:
char
vocab_filepath
:
data/vocab.txt
...
...
@@ -46,6 +46,7 @@ model:
training
:
n_epoch
:
50
accum_grad
:
4
lr
:
1e-3
lr_decay
:
0.83
weight_decay
:
1e-06
...
...
examples/tiny/s0/conf/deepspeech2.yaml
浏览文件 @
7e136d08
...
...
@@ -45,6 +45,7 @@ model:
training
:
n_epoch
:
10
accum_grad
:
1
lr
:
1e-5
lr_decay
:
1.0
weight_decay
:
1e-06
...
...
examples/tiny/s0/conf/deepspeech2_online.yaml
浏览文件 @
7e136d08
...
...
@@ -4,7 +4,7 @@ data:
dev_manifest
:
data/manifest.tiny
test_manifest
:
data/manifest.tiny
min_input_len
:
0.0
max_input_len
:
27
.0
max_input_len
:
30
.0
min_output_len
:
0.0
max_output_len
:
400.0
min_output_input_ratio
:
0.05
...
...
@@ -47,6 +47,7 @@ model:
training
:
n_epoch
:
10
accum_grad
:
1
lr
:
1e-5
lr_decay
:
1.0
weight_decay
:
1e-06
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录