PaddlePaddle / DeepSpeech

Commit 38d95784 (unverified), authored Aug 10, 2021 by Hui Zhang, committed via GitHub on Aug 10, 2021.

Merge pull request #735 from Jackwaterveg/ds2_online

Ds2 online

Parents: 99050b70, 61fe292c
Showing 29 changed files with 1,036 additions and 126 deletions (+1036 -126).
deepspeech/exps/deepspeech2/bin/export.py  +5 -1
deepspeech/exps/deepspeech2/bin/test.py  +5 -1
deepspeech/exps/deepspeech2/bin/train.py  +5 -1
deepspeech/exps/deepspeech2/config.py  +13 -15
deepspeech/exps/deepspeech2/model.py  +44 -63
deepspeech/io/sampler.py  +1 -1
deepspeech/models/ds2/deepspeech2.py  +33 -0
deepspeech/models/ds2_online/__init__.py  +17 -0
deepspeech/models/ds2_online/conv.py  +35 -0
deepspeech/models/ds2_online/deepspeech2.py  +427 -0
deepspeech/modules/subsampling.py  +3 -3
env.sh  +1 -1
examples/aishell/s0/conf/deepspeech2_online.yaml  +67 -0
examples/aishell/s0/local/export.sh  +5 -4
examples/aishell/s0/local/test.sh  +5 -3
examples/aishell/s0/local/train.sh  +5 -3
examples/aishell/s0/run.sh  +4 -3
examples/librispeech/s0/conf/deepspeech2_online.yaml  +67 -0
examples/librispeech/s0/local/export.sh  +5 -4
examples/librispeech/s0/local/test.sh  +5 -3
examples/librispeech/s0/local/train.sh  +5 -3
examples/librispeech/s0/run.sh  +4 -3
examples/tiny/s0/conf/deepspeech2_online.yaml  +69 -0
examples/tiny/s0/local/export.sh  +5 -4
examples/tiny/s0/local/test.sh  +5 -3
examples/tiny/s0/local/train.sh  +5 -3
examples/tiny/s0/run.sh  +4 -3
tests/deepspeech2_model_test.py  +1 -1
tests/deepspeech2_online_model_test.py  +186 -0
deepspeech/exps/deepspeech2/bin/export.py

```diff
@@ -30,11 +30,15 @@ def main(config, args):

 if __name__ == "__main__":
     parser = default_argument_parser()
+    parser.add_argument("--model_type")
     args = parser.parse_args()
+    if args.model_type is None:
+        args.model_type = 'offline'
+    print("model_type:{}".format(args.model_type))
     print_arguments(args)

     # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
+    config = get_cfg_defaults(args.model_type)
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:
```
deepspeech/exps/deepspeech2/bin/test.py

```diff
@@ -30,11 +30,15 @@ def main(config, args):

 if __name__ == "__main__":
     parser = default_argument_parser()
+    parser.add_argument("--model_type")
     args = parser.parse_args()
     print_arguments(args, globals())
+    if args.model_type is None:
+        args.model_type = 'offline'
+    print("model_type:{}".format(args.model_type))

     # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
+    config = get_cfg_defaults(args.model_type)
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:
```
deepspeech/exps/deepspeech2/bin/train.py

```diff
@@ -35,11 +35,15 @@ def main(config, args):

 if __name__ == "__main__":
     parser = default_argument_parser()
+    parser.add_argument("--model_type")
     args = parser.parse_args()
+    if args.model_type is None:
+        args.model_type = 'offline'
+    print("model_type:{}".format(args.model_type))
     print_arguments(args, globals())

     # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
+    config = get_cfg_defaults(args.model_type)
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:
```
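All three entry points now take the same optional flag and fall back to the offline model when it is absent. A minimal, self-contained sketch of that fallback (argparse only; the real scripts build their parser with default_argument_parser):

```python
# Sketch of the --model_type handling shared by export.py, test.py and
# train.py above; runnable on its own.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_type")   # added by this commit, no default
args = parser.parse_args([])          # simulate running without the flag

if args.model_type is None:
    args.model_type = 'offline'       # fallback applied in __main__
print("model_type:{}".format(args.model_type))  # -> model_type:offline
```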
deepspeech/exps/deepspeech2/config.py

```diff
@@ -18,21 +18,19 @@ from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer
 from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.models.ds2 import DeepSpeech2Model
+from deepspeech.models.ds2_online import DeepSpeech2ModelOnline

-_C = CfgNode()
-
-_C.data = ManifestDataset.params()
-
-_C.collator = SpeechCollator.params()
-
-_C.model = DeepSpeech2Model.params()
-
-_C.training = DeepSpeech2Trainer.params()
-
-_C.decoding = DeepSpeech2Tester.params()
-
-def get_cfg_defaults():
+
+def get_cfg_defaults(model_type='offline'):
+    _C = CfgNode()
+    _C.data = ManifestDataset.params()
+    _C.collator = SpeechCollator.params()
+    _C.training = DeepSpeech2Trainer.params()
+    _C.decoding = DeepSpeech2Tester.params()
+    if model_type == 'offline':
+        _C.model = DeepSpeech2Model.params()
+    else:
+        _C.model = DeepSpeech2ModelOnline.params()
     """Get a yacs CfgNode object with default values for my_project."""
     # Return a clone so that the defaults will not be altered
     # This is for the "local variable" use pattern
```
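get_cfg_defaults now builds the model sub-config from the requested type, so a caller selects offline or online before overlaying the experiment YAML. A sketch of the resulting call pattern (the YAML path is illustrative):

```python
# Sketch, assuming the repo layout above; the YAML path is illustrative.
from deepspeech.exps.deepspeech2.config import get_cfg_defaults

config = get_cfg_defaults('online')  # model section from DeepSpeech2ModelOnline.params()
config.merge_from_file('conf/deepspeech2_online.yaml')
config.freeze()
print(config.model.rnn_direction)
```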
deepspeech/exps/deepspeech2/model.py

```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Contains DeepSpeech2 model."""
+"""Contains DeepSpeech2 and DeepSpeech2Online model."""
 import time
 from collections import defaultdict
 from pathlib import Path
@@ -29,6 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
 from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.ds2 import DeepSpeech2InferModel
 from deepspeech.models.ds2 import DeepSpeech2Model
+from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
+from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import error_rate
@@ -119,16 +121,22 @@ class DeepSpeech2Trainer(Trainer):
         return total_loss, num_seen_utts

     def setup_model(self):
-        config = self.config
-        model = DeepSpeech2Model(
-            feat_size=self.train_loader.collate_fn.feature_size,
-            dict_size=self.train_loader.collate_fn.vocab_size,
-            num_conv_layers=config.model.num_conv_layers,
-            num_rnn_layers=config.model.num_rnn_layers,
-            rnn_size=config.model.rnn_layer_size,
-            use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
+        config = self.config.clone()
+        config.defrost()
+        assert (self.train_loader.collate_fn.feature_size ==
+                self.test_loader.collate_fn.feature_size)
+        assert (self.train_loader.collate_fn.vocab_size ==
+                self.test_loader.collate_fn.vocab_size)
+        config.model.feat_size = self.train_loader.collate_fn.feature_size
+        config.model.dict_size = self.train_loader.collate_fn.vocab_size
+        config.freeze()
+        if self.args.model_type == 'offline':
+            model = DeepSpeech2Model.from_config(config.model)
+        elif self.args.model_type == 'online':
+            model = DeepSpeech2ModelOnline.from_config(config.model)
+        else:
+            raise Exception("wrong model type")

         if self.parallel:
             model = paddle.DataParallel(model)
@@ -164,6 +172,9 @@ class DeepSpeech2Trainer(Trainer):
         config.data.manifest = config.data.dev_manifest
         dev_dataset = ManifestDataset.from_config(config)

+        config.data.manifest = config.data.test_manifest
+        test_dataset = ManifestDataset.from_config(config)
+
         if self.parallel:
             batch_sampler = SortagradDistributedBatchSampler(
                 train_dataset,
@@ -187,6 +198,11 @@ class DeepSpeech2Trainer(Trainer):
         config.collator.augmentation_config = ""
         collate_fn_dev = SpeechCollator.from_config(config)

+        config.collator.keep_transcription_text = True
+        config.collator.augmentation_config = ""
+        collate_fn_test = SpeechCollator.from_config(config)
+
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
@@ -198,7 +214,13 @@ class DeepSpeech2Trainer(Trainer):
             shuffle=False,
             drop_last=False,
             collate_fn=collate_fn_dev)
-        logger.info("Setup train/valid Dataloader!")
+        self.test_loader = DataLoader(
+            test_dataset,
+            batch_size=config.decoding.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=collate_fn_test)
+        logger.info("Setup train/valid/test Dataloader!")


 class DeepSpeech2Tester(DeepSpeech2Trainer):
@@ -329,19 +351,18 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             exit(-1)

     def export(self):
-        infer_model = DeepSpeech2InferModel.from_pretrained(
-            self.test_loader, self.config, self.args.checkpoint_path)
+        if self.args.model_type == 'offline':
+            infer_model = DeepSpeech2InferModel.from_pretrained(
+                self.test_loader, self.config, self.args.checkpoint_path)
+        elif self.args.model_type == 'online':
+            infer_model = DeepSpeech2InferModelOnline.from_pretrained(
+                self.test_loader, self.config, self.args.checkpoint_path)
+        else:
+            raise Exception("wrong model type")
+
         infer_model.eval()
         feat_dim = self.test_loader.collate_fn.feature_size
-        static_model = paddle.jit.to_static(
-            infer_model,
-            input_spec=[
-                paddle.static.InputSpec(
-                    shape=[None, None, feat_dim],
-                    dtype='float32'),  # audio, [B,T,D]
-                paddle.static.InputSpec(shape=[None],
-                                        dtype='int64'),  # audio_length, [B]
-            ])
+        static_model = infer_model.export()
         logger.info(f"Export code: {static_model.forward.code}")
         paddle.jit.save(static_model, self.args.export_path)
@@ -365,46 +386,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.iteration = 0
         self.epoch = 0

-    def setup_model(self):
-        config = self.config
-        model = DeepSpeech2Model(
-            feat_size=self.test_loader.collate_fn.feature_size,
-            dict_size=self.test_loader.collate_fn.vocab_size,
-            num_conv_layers=config.model.num_conv_layers,
-            num_rnn_layers=config.model.num_rnn_layers,
-            rnn_size=config.model.rnn_layer_size,
-            use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
-        self.model = model
-        logger.info("Setup model!")
-
-    def setup_dataloader(self):
-        config = self.config.clone()
-        config.defrost()
-        # return raw text
-        config.data.manifest = config.data.test_manifest
-        # filter test examples, will cause less examples, but no mismatch with training
-        # and can use large batch size, save training time, so filter test egs now.
-        # config.data.min_input_len = 0.0  # second
-        # config.data.max_input_len = float('inf')  # second
-        # config.data.min_output_len = 0.0  # tokens
-        # config.data.max_output_len = float('inf')  # tokens
-        # config.data.min_output_input_ratio = 0.00
-        # config.data.max_output_input_ratio = float('inf')
-
-        test_dataset = ManifestDataset.from_config(config)
-
-        config.collator.keep_transcription_text = True
-        config.collator.augmentation_config = ""
-        # return text ord id
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=config.decoding.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=SpeechCollator.from_config(config))
-        logger.info("Setup test Dataloader!")
-
     def setup_output_dir(self):
         """Create a directory used for output.
         """
```
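setup_model now clones and defrosts the trainer config so that feat_size and dict_size measured from the collator can be injected before from_config runs, without mutating self.config. The pattern in isolation (yacs only; the sizes stand in for collator-derived values):

```python
# yacs-only sketch of the clone/defrost/inject/freeze pattern used by
# setup_model above; 161 and 4300 are illustrative sizes.
from yacs.config import CfgNode

config = CfgNode(dict(model=dict(feat_size=-1, dict_size=-1)))
config.freeze()

cfg = config.clone()        # self.config stays frozen and untouched
cfg.defrost()
cfg.model.feat_size = 161   # from train_loader.collate_fn.feature_size
cfg.model.dict_size = 4300  # from train_loader.collate_fn.vocab_size
cfg.freeze()
assert config.model.feat_size == -1 and cfg.model.feat_size == 161
```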
deepspeech/io/sampler.py

```diff
@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
     """
     rng = np.random.RandomState(epoch)
     shift_len = rng.randint(0, batch_size - 1)
-    batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
+    batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
     rng.shuffle(batch_indices)
     batch_indices = [item for batch in batch_indices for item in batch]
     assert clipped is False
```
deepspeech/models/ds2/deepspeech2.py

```diff
@@ -228,6 +228,27 @@ class DeepSpeech2Model(nn.Layer):
         layer_tools.summary(model)
         return model

+    @classmethod
+    def from_config(cls, config):
+        """Build a DeepSpeech2Model from config
+        Parameters
+
+        config: yacs.config.CfgNode
+            config.model
+        Returns
+        -------
+        DeepSpeech2Model
+            The model built from config.
+        """
+        model = cls(feat_size=config.feat_size,
+                    dict_size=config.dict_size,
+                    num_conv_layers=config.num_conv_layers,
+                    num_rnn_layers=config.num_rnn_layers,
+                    rnn_size=config.rnn_layer_size,
+                    use_gru=config.use_gru,
+                    share_rnn_weights=config.share_rnn_weights)
+        return model
+

 class DeepSpeech2InferModel(DeepSpeech2Model):
     def __init__(self,
@@ -260,3 +281,15 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
         eouts, eouts_len = self.encoder(audio, audio_len)
         probs = self.decoder.softmax(eouts)
         return probs
+
+    def export(self):
+        static_model = paddle.jit.to_static(
+            self,
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=[None, None, self.encoder.feat_size],
+                    dtype='float32'),  # audio, [B,T,D]
+                paddle.static.InputSpec(shape=[None],
+                                        dtype='int64'),  # audio_length, [B]
+            ])
+        return static_model
```
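DeepSpeech2InferModel.export() returns the to_static-wrapped model that Tester.export then writes out with paddle.jit.save. A hedged sketch of consuming the saved offline graph (path and shapes illustrative; 161 is the linear-spectrogram feature size used throughout these configs):

```python
import paddle

# Reload the static graph produced by paddle.jit.save(static_model, export_path).
# The checkpoint path here is illustrative.
layer = paddle.jit.load('exp/deepspeech2/checkpoints/avg_1.jit')
audio = paddle.randn([1, 200, 161])                 # audio, [B, T, D]
audio_len = paddle.to_tensor([200], dtype='int64')  # audio_length, [B]
probs = layer(audio, audio_len)                     # per-frame softmax posteriors
```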
deepspeech/models/ds2_online/__init__.py (new file, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModelOnline
from .deepspeech2 import DeepSpeech2ModelOnline

__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
```
deepspeech/models/ds2_online/conv.py (new file, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (full header as in __init__.py above).
import paddle
from paddle import nn

from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.subsampling import Conv2dSubsampling4


class Conv2dSubsampling4Online(Conv2dSubsampling4):
    def __init__(self, idim: int, odim: int, dropout_rate: float):
        super().__init__(idim, odim, dropout_rate, None)
        self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim
        # stride_1 * (kernel_size_2 - 1) + kernel_size_1
        self.receptive_field_length = 2 * (3 - 1) + 3

    def forward(self, x: paddle.Tensor,
                x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]:
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        # b, c, t, f = paddle.shape(x)  # does not work under jit
        x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1])
        x_len = ((x_len - 1) // 2 - 1) // 2
        return x, x_len
```
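The hard-coded arithmetic above is what the online chunking later depends on: two stride-2, kernel-3 convolutions give a subsampling rate of 4 and a receptive field of 2 * (3 - 1) + 3 = 7 input frames. A plain-Python check of the length formula:

```python
def subsampled_len(n_frames: int) -> int:
    # Mirrors x_len = ((x_len - 1) // 2 - 1) // 2 in the forward pass above.
    return ((n_frames - 1) // 2 - 1) // 2

receptive_field = 2 * (3 - 1) + 3  # stride_1 * (kernel_size_2 - 1) + kernel_size_1
assert receptive_field == 7
# 7 input frames yield exactly 1 output frame; each extra 4 frames add one more.
assert subsampled_len(7) == 1
assert subsampled_len(11) == 2
assert subsampled_len(7 + 4 * 9) == 10
```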
deepspeech/models/ds2_online/deepspeech2.py (new file, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (full header as in __init__.py above).
"""Deepspeech2 ASR Online Model"""
from typing import Optional

import paddle
import paddle.nn.functional as F
from paddle import nn
from yacs.config import CfgNode

from deepspeech.models.ds2_online.conv import Conv2dSubsampling4Online
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']


class CRNNEncoder(nn.Layer):
    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=4,
                 rnn_size=1024,
                 rnn_direction='forward',
                 num_fc_layers=2,
                 fc_layers_size_list=[512, 256],
                 use_gru=False):
        super().__init__()
        self.rnn_size = rnn_size
        self.feat_size = feat_size  # 161 for linear
        self.dict_size = dict_size
        self.num_rnn_layers = num_rnn_layers
        self.num_fc_layers = num_fc_layers
        self.rnn_direction = rnn_direction
        self.fc_layers_size_list = fc_layers_size_list
        self.use_gru = use_gru
        self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0)

        self.output_dim = self.conv.output_dim

        i_size = self.conv.output_dim
        self.rnn = nn.LayerList()
        self.layernorm_list = nn.LayerList()
        self.fc_layers_list = nn.LayerList()
        if rnn_direction == 'bidirect' or rnn_direction == 'bidirectional':
            layernorm_size = 2 * rnn_size
        elif rnn_direction == 'forward':
            layernorm_size = rnn_size
        else:
            raise Exception("Wrong rnn direction")
        for i in range(0, num_rnn_layers):
            if i == 0:
                rnn_input_size = i_size
            else:
                rnn_input_size = layernorm_size
            if use_gru == True:
                self.rnn.append(
                    nn.GRU(input_size=rnn_input_size,
                           hidden_size=rnn_size,
                           num_layers=1,
                           direction=rnn_direction))
            else:
                self.rnn.append(
                    nn.LSTM(input_size=rnn_input_size,
                            hidden_size=rnn_size,
                            num_layers=1,
                            direction=rnn_direction))
            self.layernorm_list.append(nn.LayerNorm(layernorm_size))
            self.output_dim = layernorm_size

        fc_input_size = layernorm_size
        for i in range(self.num_fc_layers):
            self.fc_layers_list.append(
                nn.Linear(fc_input_size, fc_layers_size_list[i]))
            fc_input_size = fc_layers_size_list[i]
            self.output_dim = fc_layers_size_list[i]

    @property
    def output_size(self):
        return self.output_dim

    def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None):
        """Compute Encoder outputs

        Args:
            x (Tensor): [B, feature_size, D]
            x_lens (Tensor): [B]
            init_state_h_box (Tensor): init states h for RNN layers,
                [num_rnn_layers * num_directions, batch_size, hidden_size]
            init_state_c_box (Tensor): init states c for RNN layers,
                [num_rnn_layers * num_directions, batch_size, hidden_size]
        Returns:
            x (Tensor): encoder outputs, [B, size, D]
            x_lens (Tensor): encoder length, [B]
            final_state_h_box (Tensor): final states h for RNN layers,
                [num_rnn_layers * num_directions, batch_size, hidden_size]
            final_state_c_box (Tensor): final states c for RNN layers,
                [num_rnn_layers * num_directions, batch_size, hidden_size]
        """
        if init_state_h_box is not None:
            init_state_list = None
            if self.use_gru == True:
                init_state_h_list = paddle.split(
                    init_state_h_box, self.num_rnn_layers, axis=0)
                init_state_list = init_state_h_list
            else:
                init_state_h_list = paddle.split(
                    init_state_h_box, self.num_rnn_layers, axis=0)
                init_state_c_list = paddle.split(
                    init_state_c_box, self.num_rnn_layers, axis=0)
                init_state_list = [(init_state_h_list[i], init_state_c_list[i])
                                   for i in range(self.num_rnn_layers)]
        else:
            init_state_list = [None] * self.num_rnn_layers

        x, x_lens = self.conv(x, x_lens)
        final_chunk_state_list = []
        for i in range(0, self.num_rnn_layers):
            x, final_state = self.rnn[i](x, init_state_list[i],
                                         x_lens)  # [B, T, D]
            final_chunk_state_list.append(final_state)
            x = self.layernorm_list[i](x)

        for i in range(self.num_fc_layers):
            x = self.fc_layers_list[i](x)
            x = F.relu(x)

        if self.use_gru == True:
            final_chunk_state_h_box = paddle.concat(
                final_chunk_state_list, axis=0)
            final_chunk_state_c_box = init_state_c_box  # paddle.zeros_like(final_chunk_state_h_box)
        else:
            final_chunk_state_h_list = [
                final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
            ]
            final_chunk_state_c_list = [
                final_chunk_state_list[i][1] for i in range(self.num_rnn_layers)
            ]
            final_chunk_state_h_box = paddle.concat(
                final_chunk_state_h_list, axis=0)
            final_chunk_state_c_box = paddle.concat(
                final_chunk_state_c_list, axis=0)

        return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box

    def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8):
        """Compute Encoder outputs

        Args:
            x (Tensor): [B, T, D]
            x_lens (Tensor): [B]
            decoder_chunk_size: The chunk size of decoder
        Returns:
            eouts_list (List of Tensor): The list of encoder outputs in chunk_size,
                [B, chunk_size, D] * num_chunks
            eouts_lens_list (List of Tensor): The list of encoder length in chunk_size,
                [B] * num_chunks
            final_state_h_box (Tensor): final states h for RNN layers,
                [num_rnn_layers * num_directions, batch_size, hidden_size]
            final_state_c_box (Tensor): final states c for RNN layers,
                [num_rnn_layers * num_directions, batch_size, hidden_size]
        """
        subsampling_rate = self.conv.subsampling_rate
        receptive_field_length = self.conv.receptive_field_length
        chunk_size = (decoder_chunk_size - 1
                      ) * subsampling_rate + receptive_field_length
        chunk_stride = subsampling_rate * decoder_chunk_size
        max_len = x.shape[1]
        assert (chunk_size <= max_len)

        eouts_chunk_list = []
        eouts_chunk_lens_list = []

        padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
        padding = paddle.zeros((x.shape[0], padding_len, x.shape[2]))
        padded_x = paddle.concat([x, padding], axis=1)
        num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1
        num_chunk = int(num_chunk)
        chunk_state_h_box = None
        chunk_state_c_box = None
        final_state_h_box = None
        final_state_c_box = None
        for i in range(0, num_chunk):
            start = i * chunk_stride
            end = start + chunk_size
            x_chunk = padded_x[:, start:end, :]

            x_len_left = paddle.where(x_lens - i * chunk_stride < 0,
                                      paddle.zeros_like(x_lens),
                                      x_lens - i * chunk_stride)
            x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size
            x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp,
                                        x_len_left, x_chunk_len_tmp)

            eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward(
                x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box)

            eouts_chunk_list.append(eouts_chunk)
            eouts_chunk_lens_list.append(eouts_chunk_lens)
        final_state_h_box = chunk_state_h_box
        final_state_c_box = chunk_state_c_box
        return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box


class DeepSpeech2ModelOnline(nn.Layer):
    """The DeepSpeech2 network structure for online.

    :param audio_data: Audio spectrogram data layer.
    :type audio_data: Variable
    :param text_data: Transcription text data layer.
    :type text_data: Variable
    :param audio_len: Valid sequence length data layer.
    :type audio_len: Variable
    :param dict_size: Dictionary size for tokenized transcription.
    :type dict_size: int
    :param num_conv_layers: Number of stacking convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacking RNN layers.
    :type num_rnn_layers: int
    :param rnn_size: RNN layer size (dimension of RNN cells).
    :type rnn_size: int
    :param num_fc_layers: Number of stacking FC layers.
    :type num_fc_layers: int
    :param fc_layers_size_list: The list of FC layer sizes.
    :type fc_layers_size_list: [int,]
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :return: A tuple of an output unnormalized log probability layer (
             before softmax) and a ctc cost layer.
    :rtype: tuple of LayerOutput
    """

    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                num_conv_layers=2,  # Number of stacking convolution layers.
                num_rnn_layers=4,  # Number of stacking RNN layers.
                rnn_layer_size=1024,  # RNN layer size (number of RNN cells).
                num_fc_layers=2,
                fc_layers_size_list=[512, 256],
                use_gru=True,  # Use gru if set True. Use simple rnn if set False.
            ))
        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=4,
                 rnn_size=1024,
                 rnn_direction='forward',
                 num_fc_layers=2,
                 fc_layers_size_list=[512, 256],
                 use_gru=False):
        super().__init__()
        self.encoder = CRNNEncoder(
            feat_size=feat_size,
            dict_size=dict_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_direction=rnn_direction,
            num_fc_layers=num_fc_layers,
            fc_layers_size_list=fc_layers_size_list,
            rnn_size=rnn_size,
            use_gru=use_gru)
        self.decoder = CTCDecoder(
            odim=dict_size,  # <blank> is in vocab
            enc_n_units=self.encoder.output_size,
            blank_id=0,  # first token is <blank>
            dropout_rate=0.0,
            reduction=True,  # sum
            batch_average=True)  # sum / batch_size

    def forward(self, audio, audio_len, text, text_len):
        """Compute Model loss

        Args:
            audio (Tensor): [B, T, D]
            audio_len (Tensor): [B]
            text (Tensor): [B, U]
            text_len (Tensor): [B]
        Returns:
            loss (Tensor): [1]
        """
        eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
            audio, audio_len, None, None)
        loss = self.decoder(eouts, eouts_len, text, text_len)
        return loss

    @paddle.no_grad()
    def decode(self, audio, audio_len, vocab_list, decoding_method,
               lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
               cutoff_top_n, num_processes):
        # init once
        # decoders only accept string encoded in utf-8
        self.decoder.init_decode(
            beam_alpha=beam_alpha,
            beam_beta=beam_beta,
            lang_model_path=lang_model_path,
            vocab_list=vocab_list,
            decoding_method=decoding_method)

        eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
            audio, audio_len, None, None)
        probs = self.decoder.softmax(eouts)
        return self.decoder.decode_probs(
            probs.numpy(), eouts_len, vocab_list, decoding_method,
            lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
            cutoff_top_n, num_processes)

    @classmethod
    def from_pretrained(cls, dataloader, config, checkpoint_path):
        """Build a DeepSpeech2Model model from a pretrained model.

        Parameters
        ----------
        dataloader: paddle.io.DataLoader
        config: yacs.config.CfgNode
            model configs
        checkpoint_path: Path or str
            the path of pretrained model checkpoint, without extension name
        Returns
        -------
        DeepSpeech2ModelOnline
            The model built from pretrained result.
        """
        model = cls(feat_size=dataloader.collate_fn.feature_size,
                    dict_size=dataloader.collate_fn.vocab_size,
                    num_conv_layers=config.model.num_conv_layers,
                    num_rnn_layers=config.model.num_rnn_layers,
                    rnn_size=config.model.rnn_layer_size,
                    rnn_direction=config.model.rnn_direction,
                    num_fc_layers=config.model.num_fc_layers,
                    fc_layers_size_list=config.model.fc_layers_size_list,
                    use_gru=config.model.use_gru)
        infos = Checkpoint().load_parameters(
            model, checkpoint_path=checkpoint_path)
        logger.info(f"checkpoint info: {infos}")
        layer_tools.summary(model)
        return model

    @classmethod
    def from_config(cls, config):
        """Build a DeepSpeech2ModelOnline from config

        Parameters

        config: yacs.config.CfgNode
            config.model
        Returns
        -------
        DeepSpeech2ModelOnline
            The model built from config.
        """
        model = cls(feat_size=config.feat_size,
                    dict_size=config.dict_size,
                    num_conv_layers=config.num_conv_layers,
                    num_rnn_layers=config.num_rnn_layers,
                    rnn_size=config.rnn_layer_size,
                    rnn_direction=config.rnn_direction,
                    num_fc_layers=config.num_fc_layers,
                    fc_layers_size_list=config.fc_layers_size_list,
                    use_gru=config.use_gru)
        return model


class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=4,
                 rnn_size=1024,
                 rnn_direction='forward',
                 num_fc_layers=2,
                 fc_layers_size_list=[512, 256],
                 use_gru=False):
        super().__init__(
            feat_size=feat_size,
            dict_size=dict_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_size=rnn_size,
            rnn_direction=rnn_direction,
            num_fc_layers=num_fc_layers,
            fc_layers_size_list=fc_layers_size_list,
            use_gru=use_gru)

    def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
                chunk_state_c_box):
        eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder(
            audio_chunk, audio_chunk_lens, chunk_state_h_box,
            chunk_state_c_box)
        probs_chunk = self.decoder.softmax(eouts_chunk)
        return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box

    def export(self):
        static_model = paddle.jit.to_static(
            self,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None, None, self.encoder.feat_size],  # [B, chunk_size, feat_dim]
                    dtype='float32'),
                paddle.static.InputSpec(shape=[None],
                                        dtype='int64'),  # audio_length, [B]
                paddle.static.InputSpec(
                    shape=[None, None, None], dtype='float32'),
                paddle.static.InputSpec(
                    shape=[None, None, None], dtype='float32')
            ])
        return static_model
```
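forward_chunk_by_chunk derives its window geometry from the convolution front end: with subsampling_rate 4 and receptive_field_length 7, a decoder chunk of 8 output frames reads chunk_size = (8 - 1) * 4 + 7 = 35 input frames and advances by chunk_stride = 4 * 8 = 32. A plain-Python check of the padding and chunk-count arithmetic (T = 210 matches the unit test at the end of this diff):

```python
decoder_chunk_size = 8
subsampling_rate = 4          # from Conv2dSubsampling4Online
receptive_field_length = 7

chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length
chunk_stride = subsampling_rate * decoder_chunk_size
assert (chunk_size, chunk_stride) == (35, 32)

max_len = 210                 # input frames, as in tests/deepspeech2_online_model_test.py
padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
num_chunk = (max_len + padding_len - chunk_size) // chunk_stride + 1
assert (max_len + padding_len - chunk_size) % chunk_stride == 0
assert (padding_len, num_chunk) == (17, 7)
```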
deepspeech/modules/subsampling.py

One-line change in each of the three subsampling constructors; the surrounding context as shown:

```diff
@@ -92,7 +92,7 @@ class Conv2dSubsampling4(BaseSubsampling):
                  dropout_rate: float,
                  pos_enc_class: nn.Layer=PositionalEncoding):
         """Construct an Conv2dSubsampling4 object.

         Args:
             idim (int): Input dimension.
             odim (int): Output dimension.
@@ -143,7 +143,7 @@ class Conv2dSubsampling6(BaseSubsampling):
                  dropout_rate: float,
                  pos_enc_class: nn.Layer=PositionalEncoding):
         """Construct an Conv2dSubsampling6 object.

         Args:
             idim (int): Input dimension.
             odim (int): Output dimension.
@@ -196,7 +196,7 @@ class Conv2dSubsampling8(BaseSubsampling):
                  dropout_rate: float,
                  pos_enc_class: nn.Layer=PositionalEncoding):
         """Construct an Conv2dSubsampling8 object.

         Args:
             idim (int): Input dimension.
             odim (int): Output dimension.
```
env.sh

```diff
@@ -4,7 +4,7 @@ export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
 export LC_ALL=C

 # Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
-export PYTHONIOENCODING=UTF-8
+export PYTHONIOENCODING=UTF-8
 export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

 export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
```
examples/aishell/s0/conf/deepspeech2_online.yaml (new file, mode 100644)

```yaml
# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.train
  dev_manifest: data/manifest.dev
  test_manifest: data/manifest.test
  min_input_len: 0.0
  max_input_len: 27.0  # second
  min_output_len: 0.0
  max_output_len: .inf
  min_output_input_ratio: 0.00
  max_output_input_ratio: .inf

collator:
  batch_size: 32  # one gpu
  mean_std_filepath: data/mean_std.json
  unit_type: char
  vocab_filepath: data/vocab.txt
  augmentation_config: conf/augmentation.json
  random_seed: 0
  spm_model_prefix:
  specgram_type: linear  # linear, mfcc, fbank
  feat_dim:
  delta_delta: False
  stride_ms: 10.0
  window_ms: 20.0
  n_fft: None
  max_freq: None
  target_sample_rate: 16000
  use_dB_normalization: True
  target_dB: -20
  dither: 1.0
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 0

model:
  num_conv_layers: 2
  num_rnn_layers: 3
  rnn_layer_size: 1024
  rnn_direction: forward  # [forward, bidirect]
  num_fc_layers: 1
  fc_layers_size_list: 512,
  use_gru: True

training:
  n_epoch: 50
  lr: 2e-3
  lr_decay: 0.83
  weight_decay: 1e-06
  global_grad_clip: 3.0
  log_interval: 100
  checkpoint:
    kbest_n: 50
    latest_n: 5

decoding:
  batch_size: 32
  error_rate_type: cer
  decoding_method: ctc_beam_search
  lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
  alpha: 1.9
  beta: 5.0
  beam_size: 300
  cutoff_prob: 0.99
  cutoff_top_n: 40
  num_proc_bsearch: 10
```
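The model block above reaches the network through DeepSpeech2ModelOnline.from_config, which maps rnn_layer_size onto the constructor's rnn_size; feat_size and dict_size are filled in from the collator at setup time. A hedged sketch of the equivalent direct construction (the sizes marked below are illustrative):

```python
# Sketch of what this YAML's model block amounts to after from_config;
# feat_size/dict_size come from the data pipeline, and 4300 is illustrative.
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline

model = DeepSpeech2ModelOnline(
    feat_size=161,            # linear-spectrogram bins (collator, not YAML)
    dict_size=4300,           # vocabulary size (collator, not YAML)
    num_conv_layers=2,
    num_rnn_layers=3,
    rnn_size=1024,            # YAML rnn_layer_size
    rnn_direction='forward',
    num_fc_layers=1,
    fc_layers_size_list=[512],
    use_gru=True)
```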
examples/aishell/s0/local/export.sh

```diff
 #!/bin/bash

-if [ $# != 3 ];then
-    echo "usage: $0 config_path ckpt_prefix jit_model_path"
+if [ $# != 4 ];then
+    echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
     exit -1
 fi
@@ -11,6 +11,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_path_prefix=$2
 jit_model_export_path=$3
+model_type=$4

 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
 --nproc ${ngpu} \
 --config ${config_path} \
 --checkpoint_path ${ckpt_path_prefix} \
---export_path ${jit_model_export_path}
+--export_path ${jit_model_export_path} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in export!"
```
examples/aishell/s0/local/test.sh

```diff
 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type"
     exit -1
 fi
@@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then
 fi

 config_path=$1
 ckpt_prefix=$2
+model_type=$3

 # download language model
 bash local/download_lm_ch.sh
@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
 --nproc 1 \
 --config ${config_path} \
 --result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix}
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in evaluation!"
```
examples/aishell/s0/local/train.sh

```diff
 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# != 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
     exit -1
 fi
@@ -10,6 +10,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+model_type=$3

 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \
 --device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \
---output exp/${ckpt_name}
+--output exp/${ckpt_name} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in training!"
```
examples/aishell/s0/run.sh

```diff
@@ -7,6 +7,7 @@ stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
+model_type=offline

 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -21,7 +22,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -31,10 +32,10 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # export ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
+    CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
 fi
```
examples/librispeech/s0/conf/deepspeech2_online.yaml (new file, mode 100644)

```yaml
# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.train
  dev_manifest: data/manifest.dev-clean
  test_manifest: data/manifest.test-clean
  min_input_len: 0.0
  max_input_len: 27.0  # second
  min_output_len: 0.0
  max_output_len: .inf
  min_output_input_ratio: 0.00
  max_output_input_ratio: .inf

collator:
  batch_size: 20
  mean_std_filepath: data/mean_std.json
  unit_type: char
  vocab_filepath: data/vocab.txt
  augmentation_config: conf/augmentation.json
  random_seed: 0
  spm_model_prefix:
  specgram_type: linear
  target_sample_rate: 16000
  max_freq: None
  n_fft: None
  stride_ms: 10.0
  window_ms: 20.0
  delta_delta: False
  dither: 1.0
  use_dB_normalization: True
  target_dB: -20
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 0

model:
  num_conv_layers: 2
  num_rnn_layers: 3
  rnn_layer_size: 2048
  rnn_direction: forward
  num_fc_layers: 2
  fc_layers_size_list: 512, 256
  use_gru: False

training:
  n_epoch: 50
  lr: 1e-3
  lr_decay: 0.83
  weight_decay: 1e-06
  global_grad_clip: 5.0
  log_interval: 100
  checkpoint:
    kbest_n: 50
    latest_n: 5

decoding:
  batch_size: 128
  error_rate_type: wer
  decoding_method: ctc_beam_search
  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
  alpha: 1.9
  beta: 0.3
  beam_size: 500
  cutoff_prob: 1.0
  cutoff_top_n: 40
  num_proc_bsearch: 8
```
examples/librispeech/s0/local/export.sh

```diff
 #!/bin/bash

-if [ $# != 3 ];then
-    echo "usage: $0 config_path ckpt_prefix jit_model_path"
+if [ $# != 4 ];then
+    echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
     exit -1
 fi
@@ -11,6 +11,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_path_prefix=$2
 jit_model_export_path=$3
+model_type=$4

 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
 --nproc ${ngpu} \
 --config ${config_path} \
 --checkpoint_path ${ckpt_path_prefix} \
---export_path ${jit_model_export_path}
+--export_path ${jit_model_export_path} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in export!"
```
examples/librispeech/s0/local/test.sh

```diff
 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type"
     exit -1
 fi
@@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then
 fi

 config_path=$1
 ckpt_prefix=$2
+model_type=$3

 # download language model
 bash local/download_lm_en.sh
@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
 --nproc 1 \
 --config ${config_path} \
 --result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix}
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in evaluation!"
```
examples/librispeech/s0/local/train.sh

```diff
 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# != 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
     exit -1
 fi
@@ -10,6 +10,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+model_type=$3

 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -23,7 +24,8 @@ python3 -u ${BIN_DIR}/train.py \
 --device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \
---output exp/${ckpt_name}
+--output exp/${ckpt_name} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in training!"
```
examples/librispeech/s0/run.sh

```diff
@@ -6,6 +6,7 @@ stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=30
+model_type=offline

 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

 avg_ckpt=avg_${avg_num}
@@ -19,7 +20,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} ${model_type}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -29,10 +30,10 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # export ckpt avg_n
-    CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
+    CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
 fi
```
examples/tiny/s0/conf/deepspeech2_online.yaml (new file, mode 100644)

```yaml
# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.tiny
  dev_manifest: data/manifest.tiny
  test_manifest: data/manifest.tiny
  min_input_len: 0.0
  max_input_len: 27.0
  min_output_len: 0.0
  max_output_len: 400.0
  min_output_input_ratio: 0.05
  max_output_input_ratio: 10.0

collator:
  mean_std_filepath: data/mean_std.json
  unit_type: char
  vocab_filepath: data/vocab.txt
  augmentation_config: conf/augmentation.json
  random_seed: 0
  spm_model_prefix:
  specgram_type: linear
  feat_dim:
  delta_delta: False
  stride_ms: 10.0
  window_ms: 20.0
  n_fft: None
  max_freq: None
  target_sample_rate: 16000
  use_dB_normalization: True
  target_dB: -20
  dither: 1.0
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 0
  batch_size: 4

model:
  num_conv_layers: 2
  num_rnn_layers: 4
  rnn_layer_size: 2048
  rnn_direction: forward
  num_fc_layers: 2
  fc_layers_size_list: 512, 256
  use_gru: True

training:
  n_epoch: 10
  lr: 1e-5
  lr_decay: 1.0
  weight_decay: 1e-06
  global_grad_clip: 5.0
  log_interval: 1
  checkpoint:
    kbest_n: 3
    latest_n: 2

decoding:
  batch_size: 128
  error_rate_type: wer
  decoding_method: ctc_beam_search
  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
  alpha: 2.5
  beta: 0.3
  beam_size: 500
  cutoff_prob: 1.0
  cutoff_top_n: 40
  num_proc_bsearch: 8
```
examples/tiny/s0/local/export.sh

```diff
 #!/bin/bash

-if [ $# != 3 ];then
-    echo "usage: $0 config_path ckpt_prefix jit_model_path"
+if [ $# != 4 ];then
+    echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
     exit -1
 fi
@@ -11,6 +11,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_path_prefix=$2
 jit_model_export_path=$3
+model_type=$4

 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
 --nproc ${ngpu} \
 --config ${config_path} \
 --checkpoint_path ${ckpt_path_prefix} \
---export_path ${jit_model_export_path}
+--export_path ${jit_model_export_path} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in export!"
```
examples/tiny/s0/local/test.sh

```diff
 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type"
     exit -1
 fi
@@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then
 fi

 config_path=$1
 ckpt_prefix=$2
+model_type=$3

 # download language model
 bash local/download_lm_en.sh
@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
 --nproc 1 \
 --config ${config_path} \
 --result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix}
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in evaluation!"
```
examples/tiny/s0/local/train.sh

```diff
 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# != 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
     exit -1
 fi
@@ -10,6 +10,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+model_type=$3

 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \
 --device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \
---output exp/${ckpt_name}
+--output exp/${ckpt_name} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in training!"
```
examples/tiny/s0/run.sh

```diff
@@ -7,6 +7,7 @@ stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
+model_type=offline

 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -21,7 +22,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -31,10 +32,10 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # export ckpt avg_n
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
 fi
```
tests/deepspeech2_model_test.py

```diff
@@ -16,7 +16,7 @@ import unittest
 import numpy as np
 import paddle

-from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.ds2 import DeepSpeech2Model


 class TestDeepSpeech2Model(unittest.TestCase):
```
tests/deepspeech2_online_model_test.py (new file, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (full header as in __init__.py above).
import unittest

import numpy as np
import paddle

from deepspeech.models.ds2_online import DeepSpeech2ModelOnline


class TestDeepSpeech2ModelOnline(unittest.TestCase):
    def setUp(self):
        paddle.set_device('cpu')
        self.batch_size = 2
        self.feat_dim = 161
        max_len = 210

        # (B, T, D)
        audio = np.random.randn(self.batch_size, max_len, self.feat_dim)
        audio_len = np.random.randint(max_len, size=self.batch_size)
        audio_len[-1] = max_len
        # (B, U)
        text = np.array([[1, 2], [1, 2]])
        text_len = np.array([2] * self.batch_size)

        self.audio = paddle.to_tensor(audio, dtype='float32')
        self.audio_len = paddle.to_tensor(audio_len, dtype='int64')
        self.text = paddle.to_tensor(text, dtype='int32')
        self.text_len = paddle.to_tensor(text_len, dtype='int64')

    def test_ds2_1(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=False)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_2(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=True)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_3(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=False)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_4(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=True)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_5(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=False)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_6(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            rnn_direction='bidirect',
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=False)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_7(self):
        use_gru = False
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=1,
            rnn_size=1024,
            rnn_direction='forward',
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=use_gru)
        model.eval()
        paddle.device.set_device("cpu")
        de_ch_size = 8

        eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
            self.audio, self.audio_len)
        eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
            self.audio, self.audio_len, de_ch_size)
        eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
        eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
        decode_max_len = eouts.shape[1]
        eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
        self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
        self.assertEqual(
            paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
        if use_gru == False:
            self.assertEqual(
                paddle.allclose(final_state_c_box, final_state_c_box_chk),
                True)

    def test_ds2_8(self):
        use_gru = True
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=1,
            rnn_size=1024,
            rnn_direction='forward',
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=use_gru)
        model.eval()
        paddle.device.set_device("cpu")
        de_ch_size = 8

        eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
            self.audio, self.audio_len)
        eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
            self.audio, self.audio_len, de_ch_size)
        eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
        eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
        decode_max_len = eouts.shape[1]
        eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
        self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
        self.assertEqual(
            paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
        if use_gru == False:
            self.assertEqual(
                paddle.allclose(final_state_c_box, final_state_c_box_chk),
                True)


if __name__ == '__main__':
    unittest.main()
```