PaddlePaddle / DeepSpeech
Commit 91bc5959, authored Sep 10, 2021 by Hui Zhang; committed via GitHub on Sep 10, 2021.
Merge pull request #820 from PaddlePaddle/ctc
more ctc config
Parents: 5e063adf, 2480be8e
Showing 26 changed files with 163 additions and 112 deletions (+163 -112).
deepspeech/exps/u2/model.py  +32 -29
deepspeech/exps/u2_kaldi/model.py  +31 -28
deepspeech/exps/u2_st/model.py  +31 -28
deepspeech/models/u2.py  +12 -6
deepspeech/models/u2_st.py  +9 -7
deepspeech/modules/loss.py  +3 -3
deepspeech/training/gradclip.py  +9 -7
doc/src/deepspeech_architecture.md  +0 -2
examples/aishell/s1/conf/chunk_conformer.yaml  +2 -0
examples/aishell/s1/conf/conformer.yaml  +2 -0
examples/librispeech/s1/conf/chunk_conformer.yaml  +2 -0
examples/librispeech/s1/conf/chunk_transformer.yaml  +2 -0
examples/librispeech/s1/conf/conformer.yaml  +2 -0
examples/librispeech/s1/conf/transformer.yaml  +3 -1
examples/librispeech/s1/local/train.sh  +1 -1
examples/librispeech/s2/conf/chunk_conformer.yaml  +2 -0
examples/librispeech/s2/conf/chunk_transformer.yaml  +2 -0
examples/librispeech/s2/conf/conformer.yaml  +2 -0
examples/librispeech/s2/conf/transformer.yaml  +2 -0
examples/ted_en_zh/t0/conf/transformer.yaml  +2 -0
examples/ted_en_zh/t0/conf/transformer_joint_noam.yaml  +2 -0
examples/timit/s1/conf/transformer.yaml  +2 -0
examples/tiny/s1/conf/chunk_confermer.yaml  +2 -0
examples/tiny/s1/conf/chunk_transformer.yaml  +2 -0
examples/tiny/s1/conf/conformer.yaml  +2 -0
examples/tiny/s1/conf/transformer.yaml  +2 -0
deepspeech/exps/u2/model.py

@@ -34,6 +34,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.u2 import U2Model
 from deepspeech.training.optimizer import OptimizerFactory
 from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate

@@ -184,40 +185,42 @@ class U2Trainer(Trainer):
             self.save(tag='init')
 
         self.lr_scheduler.step(self.iteration)
-        if self.parallel:
+        if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
             self.train_loader.batch_sampler.set_epoch(self.epoch)
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
-                    data_start_time = time.time()
-            except Exception as e:
-                logger.error(e)
-                raise e
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
+                    data_start_time = time.time()
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
 
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts
 
             logger.info(
                 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
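The Timer imported above is used only as a context manager: the format-string argument ("Epoch-Train Time Cost: {}", "Eval Time Cost: {}") receives the elapsed wall-clock time when the block exits. The real implementation lives in deepspeech/training/timer.py and is not part of this diff; the sketch below is a minimal stand-in under that assumption, printing instead of using the project logger.

```python
# Minimal sketch of a Timer compatible with `with Timer("Epoch-Train Time Cost: {}")`.
# This is an assumption about the interface implied by the calls above, not the
# actual deepspeech.training.timer.Timer.
import time


class Timer:
    def __init__(self, message: str = "Time cost: {}"):
        self.message = message      # template with one `{}` placeholder
        self.start = 0.0

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        elapsed = time.time() - self.start
        print(self.message.format(f"{elapsed:.3f}s"))  # stand-in for logger.info
        return False                # never swallow exceptions from the block


with Timer("Epoch-Train Time Cost: {}"):
    time.sleep(0.1)
```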
deepspeech/exps/u2_kaldi/model.py

@@ -32,6 +32,7 @@ from deepspeech.io.dataloader import BatchDataLoader
 from deepspeech.models.u2 import U2Model
 from deepspeech.training.optimizer import OptimizerFactory
 from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate

@@ -190,35 +191,37 @@ class U2Trainer(Trainer):
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
-                    data_start_time = time.time()
-            except Exception as e:
-                logger.error(e)
-                raise e
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
+                    data_start_time = time.time()
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
 
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts
 
             logger.info(
                 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
deepspeech/exps/u2_st/model.py

@@ -38,6 +38,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.u2_st import U2STModel
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.training.scheduler import WarmupLR
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import bleu_score
 from deepspeech.utils import ctc_utils

@@ -207,35 +208,37 @@ class U2STTrainer(Trainer):
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
-                    data_start_time = time.time()
-            except Exception as e:
-                logger.error(e)
-                raise e
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
+                    data_start_time = time.time()
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
 
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts
 
             logger.info(
                 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
deepspeech/models/u2.py

@@ -115,7 +115,8 @@ class U2BaseModel(nn.Layer):
                  ctc_weight: float=0.5,
                  ignore_id: int=IGNORE_ID,
                  lsm_weight: float=0.0,
-                 length_normalized_loss: bool=False):
+                 length_normalized_loss: bool=False,
+                 **kwargs):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
 
         super().__init__()

@@ -661,9 +662,7 @@ class U2BaseModel(nn.Layer):
             xs, offset, required_cache_size, subsampling_cache,
             elayers_output_cache, conformer_cnn_cache)
 
-    # @jit.to_static([
-    #     paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D]
-    # ])
+    # @jit.to_static
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
         """ Export interface for c++ call, apply linear transform and log
         softmax before ctc

@@ -830,6 +829,7 @@ class U2Model(U2BaseModel):
         Returns:
             int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
         """
+        # cmvn
         if configs['cmvn_file'] is not None:
             mean, istd = load_cmvn(configs['cmvn_file'],
                                    configs['cmvn_file_type'])

@@ -839,11 +839,13 @@ class U2Model(U2BaseModel):
         else:
             global_cmvn = None
 
+        # input & output dim
         input_dim = configs['input_dim']
         vocab_size = configs['output_dim']
         assert input_dim != 0, input_dim
         assert vocab_size != 0, vocab_size
 
+        # encoder
         encoder_type = configs.get('encoder', 'transformer')
         logger.info(f"U2 Encoder type: {encoder_type}")
         if encoder_type == 'transformer':

@@ -855,17 +857,21 @@ class U2Model(U2BaseModel):
         else:
             raise ValueError(f"not support encoder type:{encoder_type}")
 
+        # decoder
         decoder = TransformerDecoder(vocab_size,
                                      encoder.output_size(),
                                      **configs['decoder_conf'])
 
+        # ctc decoder and ctc loss
+        model_conf = configs['model_conf']
         ctc = CTCDecoder(
             odim=vocab_size,
             enc_n_units=encoder.output_size(),
             blank_id=0,
-            dropout_rate=0.0,
+            dropout_rate=model_conf['ctc_dropoutrate'],
             reduction=True,  # sum
             batch_average=True,  # sum / batch_size
-            grad_norm_type='instance')
+            grad_norm_type=model_conf['ctc_grad_norm_type'])
 
         return vocab_size, encoder, decoder, ctc
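Because the dropout rate and gradient-norm type are now read with direct indexing (model_conf['ctc_dropoutrate'], model_conf['ctc_grad_norm_type']) instead of being hard-coded, a config written before this change that lacks the two new keys would raise a KeyError when the model is built, which is why every example YAML below gains them. A small illustrative sketch of that lookup, with a plain dict standing in for the parsed config:

```python
# Illustrative only: `model_conf` stands in for the parsed `model_conf` section
# of one of the example YAML files; the keys mirror this diff.
model_conf = {
    'ctc_weight': 0.3,
    'lsm_weight': 0.1,
    'length_normalized_loss': False,
}

try:
    dropout_rate = model_conf['ctc_dropoutrate']        # new required key
    grad_norm_type = model_conf['ctc_grad_norm_type']   # new required key
except KeyError as missing:
    # configs written before this change would land here
    print(f"model_conf is missing the new CTC option: {missing}")
```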
deepspeech/models/u2_st.py

@@ -413,26 +413,26 @@ class U2STBaseModel(nn.Layer):
         best_hyps = best_hyps[:, 1:]
         return best_hyps
 
-    @jit.to_static
+    # @jit.to_static
     def subsampling_rate(self) -> int:
         """ Export interface for c++ call, return subsampling_rate of the
         model
         """
         return self.encoder.embed.subsampling_rate
 
-    @jit.to_static
+    # @jit.to_static
     def right_context(self) -> int:
         """ Export interface for c++ call, return right_context of the model
         """
         return self.encoder.embed.right_context
 
-    @jit.to_static
+    # @jit.to_static
     def sos_symbol(self) -> int:
         """ Export interface for c++ call, return sos symbol id of the model
         """
         return self.sos
 
-    @jit.to_static
+    # @jit.to_static
     def eos_symbol(self) -> int:
         """ Export interface for c++ call, return eos symbol id of the model
         """

@@ -468,7 +468,7 @@ class U2STBaseModel(nn.Layer):
             xs, offset, required_cache_size, subsampling_cache,
             elayers_output_cache, conformer_cnn_cache)
 
-    @jit.to_static
+    # @jit.to_static
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
         """ Export interface for c++ call, apply linear transform and log
         softmax before ctc

@@ -643,14 +643,16 @@ class U2STModel(U2STBaseModel):
             decoder = TransformerDecoder(vocab_size,
                                          encoder.output_size(),
                                          **configs['decoder_conf'])
+            # ctc decoder and ctc loss
+            model_conf = configs['model_conf']
             ctc = CTCDecoder(
                 odim=vocab_size,
                 enc_n_units=encoder.output_size(),
                 blank_id=0,
-                dropout_rate=0.0,
+                dropout_rate=model_conf['ctc_dropout_rate'],
                 reduction=True,  # sum
                 batch_average=True,  # sum / batch_size
-                grad_norm_type='instance')
+                grad_norm_type=model_conf['ctc_grad_norm_type'])
 
             return vocab_size, encoder, (st_decoder, decoder, ctc)
         else:
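The @jit.to_static decorators on the C++-export helper methods are commented out rather than removed, so the methods keep working in dynamic-graph mode and can be re-decorated later for export. Re-enabling them would look roughly like the sketch below, which follows the commented-out InputSpec hint in u2.py; the standalone function and the feat_dim of 80 are hypothetical and only illustrate the decorator usage, not the project's export path.

```python
# Hypothetical, standalone illustration of the paddle.jit.to_static usage that
# the commented-out decorators point at; the function below is not part of the
# repository, and feat_dim=80 is an assumed feature size.
import paddle
from paddle.static import InputSpec


@paddle.jit.to_static(
    input_spec=[InputSpec(shape=[1, None, 80], dtype='float32')])  # audio feat, [B, T, D]
def ctc_activation_like(xs: paddle.Tensor) -> paddle.Tensor:
    # stand-in for ctc_activation: linear projection omitted, log_softmax kept
    return paddle.nn.functional.log_softmax(xs, axis=-1)


probs = ctc_activation_like(paddle.randn([1, 50, 80]))
print(probs.shape)
```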
deepspeech/modules/loss.py

@@ -36,16 +36,16 @@ class CTCLoss(nn.Layer):
             f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
 
         # instance for norm_by_times
-        # batchsize for norm_by_batchsize
+        # batch for norm_by_batchsize
         # frame for norm_by_total_logits_len
-        assert grad_norm_type in ('instance', 'batchsize', 'frame', None)
+        assert grad_norm_type in ('instance', 'batch', 'frame', None)
         self.norm_by_times = False
         self.norm_by_batchsize = False
         self.norm_by_total_logits_len = False
         logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
         if grad_norm_type == 'instance':
             self.norm_by_times = True
-        if grad_norm_type == 'batchsize':
+        if grad_norm_type == 'batch':
             self.norm_by_times = True
         if grad_norm_type == 'frame':
             self.norm_by_total_logits_len = True
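The accepted values for grad_norm_type change from ('instance', 'batchsize', 'frame', None) to ('instance', 'batch', 'frame', None), matching the 'instance' value now supplied through model_conf['ctc_grad_norm_type'] in the configs above. The sketch below only illustrates the value-to-flag mapping described by the comment block; it is not the repository implementation.

```python
# Illustrative mapping of grad_norm_type values to the CTC-loss normalization
# flags, following the comment block in the hunk above (a sketch, not the
# repository code).
GRAD_NORM_FLAGS = {
    'instance': 'norm_by_times',           # per-utterance normalization
    'batch': 'norm_by_batchsize',          # renamed from 'batchsize'
    'frame': 'norm_by_total_logits_len',   # normalize by total logit length
    None: None,                            # no normalization
}


def resolve_flag(grad_norm_type):
    """Return the flag name selected by grad_norm_type, or raise on bad input."""
    if grad_norm_type not in GRAD_NORM_FLAGS:
        raise ValueError(f"unsupported grad_norm_type: {grad_norm_type}")
    return GRAD_NORM_FLAGS[grad_norm_type]


print(resolve_flag('batch'))   # -> 'norm_by_batchsize'
```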
deepspeech/training/gradclip.py

@@ -47,9 +47,10 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
             sum_square = layers.reduce_sum(square)
             sum_square_list.append(sum_square)
 
-            # debug log
-            logger.debug(
-                f"Grad Before Clip: {p.name}: {float(sum_square.sqrt())}")
+            # debug log, not dump all since slow down train process
+            if i < 10:
+                logger.debug(
+                    f"Grad Before Clip: {p.name}: {float(sum_square.sqrt())}")
 
         # all parameters have been filterd out
         if len(sum_square_list) == 0:

@@ -75,9 +76,10 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
             new_grad = layers.elementwise_mul(x=g, y=clip_var)
             params_and_grads.append((p, new_grad))
 
-            # debug log
-            logger.debug(
-                f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}"
-            )
+            # debug log, not dump all since slow down train process
+            if i < 10:
+                logger.debug(
+                    f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}"
+                )
 
         return params_and_grads
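Both debug logs are now emitted only for the first ten parameters of each clipping pass, since dumping every gradient norm slows training. The `i` index implies the surrounding loop enumerates params_grads; that enumerate call sits outside the hunks shown. A small standalone sketch of the pattern, with placeholder data:

```python
# Sketch of the "log only the first N items" pattern used above; `params_grads`
# and the logger here are placeholders, and the enumerate() is assumed to be
# added in a part of the loop not shown in this hunk.
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("gradclip-sketch")

params_grads = [(f"param_{k}", float(k)) for k in range(50)]

for i, (name, grad_norm) in enumerate(params_grads):
    # debug log, not dump all since slow down train process
    if i < 10:
        logger.debug(f"Grad Before Clip: {name}: {grad_norm}")
```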
doc/src/deepspeech_architecture.md

@@ -183,5 +183,3 @@ bash run.sh --stage 0 --stop_stage 2 --model_type offline --conf_path conf/deeps
 cd examples/aishell/s0
 bash run.sh --stage 3 --stop_stage 5 --model_type offline --conf_path conf/deepspeech2.yaml
 ```
examples/aishell/s1/conf/chunk_conformer.yaml

@@ -76,6 +76,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/aishell/s1/conf/conformer.yaml

@@ -71,6 +71,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/librispeech/s1/conf/chunk_conformer.yaml

@@ -76,6 +76,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/librispeech/s1/conf/chunk_transformer.yaml

@@ -69,6 +69,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/librispeech/s1/conf/conformer.yaml

@@ -72,6 +72,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/librispeech/s1/conf/transformer.yaml

@@ -33,7 +33,7 @@ collator:
   keep_transcription_text: False
   sortagrad: True
   shuffle_method: batch_shuffle
-  num_workers: 2
+  num_workers: 0
 
 
 # network architecture

@@ -67,6 +67,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/librispeech/s1/local/train.sh

@@ -20,7 +20,7 @@ echo "using ${device}..."
 mkdir -p exp
 
 seed=10086
-if [ ${seed} != 0]; then
+if [ ${seed} != 0 ]; then
     export FLAGS_cudnn_deterministic=True
 fi
examples/librispeech/s2/conf/chunk_conformer.yaml

@@ -76,6 +76,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/librispeech/s2/conf/chunk_transformer.yaml

@@ -69,6 +69,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/librispeech/s2/conf/conformer.yaml

@@ -72,6 +72,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/librispeech/s2/conf/transformer.yaml

@@ -58,6 +58,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/ted_en_zh/t0/conf/transformer.yaml

@@ -68,6 +68,8 @@ model:
   model_conf:
     asr_weight: 0.0
     ctc_weight: 0.0
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/ted_en_zh/t0/conf/transformer_joint_noam.yaml

@@ -68,6 +68,8 @@ model:
   model_conf:
     asr_weight: 0.5
    ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/timit/s1/conf/transformer.yaml

@@ -66,6 +66,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/tiny/s1/conf/chunk_confermer.yaml

@@ -76,6 +76,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/tiny/s1/conf/chunk_transformer.yaml

@@ -69,6 +69,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/tiny/s1/conf/conformer.yaml

@@ -72,6 +72,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
examples/tiny/s1/conf/transformer.yaml

@@ -66,6 +66,8 @@ model:
   # hybrid CTC/attention
   model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false