PaddlePaddle / DeepSpeech — Commit 38d95784
Unverified commit 38d95784, authored Aug 10, 2021 by Hui Zhang, committed via GitHub on Aug 10, 2021.

Merge pull request #735 from Jackwaterveg/ds2_online

Ds2 online

Parents: 99050b70, 61fe292c
Showing 29 changed files, with 1036 additions and 126 deletions (+1036 −126).
deepspeech/exps/deepspeech2/bin/export.py               +5    −1
deepspeech/exps/deepspeech2/bin/test.py                 +5    −1
deepspeech/exps/deepspeech2/bin/train.py                +5    −1
deepspeech/exps/deepspeech2/config.py                   +13   −15
deepspeech/exps/deepspeech2/model.py                    +44   −63
deepspeech/io/sampler.py                                +1    −1
deepspeech/models/ds2/deepspeech2.py                    +33   −0
deepspeech/models/ds2_online/__init__.py                +17   −0
deepspeech/models/ds2_online/conv.py                    +35   −0
deepspeech/models/ds2_online/deepspeech2.py             +427  −0
deepspeech/modules/subsampling.py                       +3    −3
env.sh                                                  +1    −1
examples/aishell/s0/conf/deepspeech2_online.yaml        +67   −0
examples/aishell/s0/local/export.sh                     +5    −4
examples/aishell/s0/local/test.sh                       +5    −3
examples/aishell/s0/local/train.sh                      +5    −3
examples/aishell/s0/run.sh                              +4    −3
examples/librispeech/s0/conf/deepspeech2_online.yaml    +67   −0
examples/librispeech/s0/local/export.sh                 +5    −4
examples/librispeech/s0/local/test.sh                   +5    −3
examples/librispeech/s0/local/train.sh                  +5    −3
examples/librispeech/s0/run.sh                          +4    −3
examples/tiny/s0/conf/deepspeech2_online.yaml           +69   −0
examples/tiny/s0/local/export.sh                        +5    −4
examples/tiny/s0/local/test.sh                          +5    −3
examples/tiny/s0/local/train.sh                         +5    −3
examples/tiny/s0/run.sh                                 +4    −3
tests/deepspeech2_model_test.py                         +1    −1
tests/deepspeech2_online_model_test.py                  +186  −0
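In summary, this merge adds a streaming ("online") variant of DeepSpeech2 alongside the existing offline model: a new deepspeech.models.ds2_online package with a chunk-capable CRNN encoder (Conv2dSubsampling4Online plus a forward-direction RNN stack and FC layers), a --model_type flag (defaulting to 'offline') threaded through the train/test/export entry points and the aishell, librispeech and tiny recipe scripts, per-recipe deepspeech2_online.yaml configs, and a unit test comparing whole-utterance and chunk-by-chunk encoding.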
deepspeech/exps/deepspeech2/bin/export.py (+5 −1)

@@ -30,11 +30,15 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    parser.add_argument("--model_type")
     args = parser.parse_args()
+    if args.model_type is None:
+        args.model_type = 'offline'
+    print("model_type:{}".format(args.model_type))
     print_arguments(args)

     # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
+    config = get_cfg_defaults(args.model_type)
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:
deepspeech/exps/deepspeech2/bin/test.py (+5 −1)

@@ -30,11 +30,15 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    parser.add_argument("--model_type")
     args = parser.parse_args()
     print_arguments(args, globals())
+    if args.model_type is None:
+        args.model_type = 'offline'
+    print("model_type:{}".format(args.model_type))

     # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
+    config = get_cfg_defaults(args.model_type)
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:
deepspeech/exps/deepspeech2/bin/train.py (+5 −1)

@@ -35,11 +35,15 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    parser.add_argument("--model_type")
     args = parser.parse_args()
+    if args.model_type is None:
+        args.model_type = 'offline'
+    print("model_type:{}".format(args.model_type))
     print_arguments(args, globals())

     # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
+    config = get_cfg_defaults(args.model_type)
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:
deepspeech/exps/deepspeech2/config.py (+13 −15)

@@ -18,21 +18,19 @@ from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer
 from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.models.ds2 import DeepSpeech2Model
+from deepspeech.models.ds2_online import DeepSpeech2ModelOnline

-_C = CfgNode()
-_C.data = ManifestDataset.params()
-_C.collator = SpeechCollator.params()
-_C.model = DeepSpeech2Model.params()
-_C.training = DeepSpeech2Trainer.params()
-_C.decoding = DeepSpeech2Tester.params()
-
-def get_cfg_defaults():
+def get_cfg_defaults(model_type='offline'):
+    _C = CfgNode()
+    _C.data = ManifestDataset.params()
+    _C.collator = SpeechCollator.params()
+    _C.training = DeepSpeech2Trainer.params()
+    _C.decoding = DeepSpeech2Tester.params()
+    if model_type == 'offline':
+        _C.model = DeepSpeech2Model.params()
+    else:
+        _C.model = DeepSpeech2ModelOnline.params()
     """Get a yacs CfgNode object with default values for my_project."""
     # Return a clone so that the defaults will not be altered
     # This is for the "local variable" use pattern
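Since get_cfg_defaults() now keys the model sub-config on the variant, entry points resolve model_type before merging the recipe YAML. A minimal sketch of that call order (the config path is illustrative):

from deepspeech.exps.deepspeech2.config import get_cfg_defaults

# 'online' selects DeepSpeech2ModelOnline.params() for config.model;
# anything merged afterwards overlays these defaults.
config = get_cfg_defaults('online')
config.merge_from_file('conf/deepspeech2_online.yaml')  # hypothetical recipe path
config.freeze()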
deepspeech/exps/deepspeech2/model.py (+44 −63)

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Contains DeepSpeech2 model."""
+"""Contains DeepSpeech2 and DeepSpeech2Online model."""
 import time
 from collections import defaultdict
 from pathlib import Path
@@ -29,6 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
 from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.ds2 import DeepSpeech2InferModel
 from deepspeech.models.ds2 import DeepSpeech2Model
+from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
+from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import error_rate
@@ -119,16 +121,22 @@ class DeepSpeech2Trainer(Trainer):
         return total_loss, num_seen_utts

     def setup_model(self):
-        config = self.config
-        model = DeepSpeech2Model(
-            feat_size=self.train_loader.collate_fn.feature_size,
-            dict_size=self.train_loader.collate_fn.vocab_size,
-            num_conv_layers=config.model.num_conv_layers,
-            num_rnn_layers=config.model.num_rnn_layers,
-            rnn_size=config.model.rnn_layer_size,
-            use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
+        config = self.config.clone()
+        config.defrost()
+        assert (self.train_loader.collate_fn.feature_size ==
+                self.test_loader.collate_fn.feature_size)
+        assert (self.train_loader.collate_fn.vocab_size ==
+                self.test_loader.collate_fn.vocab_size)
+        config.model.feat_size = self.train_loader.collate_fn.feature_size
+        config.model.dict_size = self.train_loader.collate_fn.vocab_size
+        config.freeze()
+        if self.args.model_type == 'offline':
+            model = DeepSpeech2Model.from_config(config.model)
+        elif self.args.model_type == 'online':
+            model = DeepSpeech2ModelOnline.from_config(config.model)
+        else:
+            raise Exception("wrong model type")
         if self.parallel:
             model = paddle.DataParallel(model)
@@ -164,6 +172,9 @@ class DeepSpeech2Trainer(Trainer):
         config.data.manifest = config.data.dev_manifest
         dev_dataset = ManifestDataset.from_config(config)
+        config.data.manifest = config.data.test_manifest
+        test_dataset = ManifestDataset.from_config(config)

         if self.parallel:
             batch_sampler = SortagradDistributedBatchSampler(
                 train_dataset,
@@ -187,6 +198,11 @@ class DeepSpeech2Trainer(Trainer):
         config.collator.augmentation_config = ""
         collate_fn_dev = SpeechCollator.from_config(config)
+        config.collator.keep_transcription_text = True
+        config.collator.augmentation_config = ""
+        collate_fn_test = SpeechCollator.from_config(config)

         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
@@ -198,7 +214,13 @@ class DeepSpeech2Trainer(Trainer):
             shuffle=False,
             drop_last=False,
             collate_fn=collate_fn_dev)
-        logger.info("Setup train/valid Dataloader!")
+        self.test_loader = DataLoader(
+            test_dataset,
+            batch_size=config.decoding.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=collate_fn_test)
+        logger.info("Setup train/valid/test Dataloader!")

 class DeepSpeech2Tester(DeepSpeech2Trainer):
@@ -329,19 +351,18 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             exit(-1)

     def export(self):
-        infer_model = DeepSpeech2InferModel.from_pretrained(
-            self.test_loader, self.config, self.args.checkpoint_path)
+        if self.args.model_type == 'offline':
+            infer_model = DeepSpeech2InferModel.from_pretrained(
+                self.test_loader, self.config, self.args.checkpoint_path)
+        elif self.args.model_type == 'online':
+            infer_model = DeepSpeech2InferModelOnline.from_pretrained(
+                self.test_loader, self.config, self.args.checkpoint_path)
+        else:
+            raise Exception("wrong model type")
         infer_model.eval()
         feat_dim = self.test_loader.collate_fn.feature_size
-        static_model = paddle.jit.to_static(
-            infer_model,
-            input_spec=[
-                paddle.static.InputSpec(
-                    shape=[None, None, feat_dim],
-                    dtype='float32'),  # audio, [B,T,D]
-                paddle.static.InputSpec(shape=[None], dtype='int64'),  # audio_length, [B]
-            ])
+        static_model = infer_model.export()
         logger.info(f"Export code: {static_model.forward.code}")
         paddle.jit.save(static_model, self.args.export_path)
@@ -365,46 +386,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.iteration = 0
         self.epoch = 0

-    def setup_model(self):
-        config = self.config
-        model = DeepSpeech2Model(
-            feat_size=self.test_loader.collate_fn.feature_size,
-            dict_size=self.test_loader.collate_fn.vocab_size,
-            num_conv_layers=config.model.num_conv_layers,
-            num_rnn_layers=config.model.num_rnn_layers,
-            rnn_size=config.model.rnn_layer_size,
-            use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
-        self.model = model
-        logger.info("Setup model!")
-
-    def setup_dataloader(self):
-        config = self.config.clone()
-        config.defrost()
-        # return raw text
-        config.data.manifest = config.data.test_manifest
-        # filter test examples, will cause less examples, but no mismatch with training
-        # and can use large batch size , save training time, so filter test egs now.
-        # config.data.min_input_len = 0.0  # second
-        # config.data.max_input_len = float('inf')  # second
-        # config.data.min_output_len = 0.0  # tokens
-        # config.data.max_output_len = float('inf')  # tokens
-        # config.data.min_output_input_ratio = 0.00
-        # config.data.max_output_input_ratio = float('inf')
-        test_dataset = ManifestDataset.from_config(config)
-        config.collator.keep_transcription_text = True
-        config.collator.augmentation_config = ""
-        # return text ord id
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=config.decoding.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=SpeechCollator.from_config(config))
-        logger.info("Setup test Dataloader!")

     def setup_output_dir(self):
         """Create a directory used for output.
         """
deepspeech/io/sampler.py (+1 −1)

@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
     """
     rng = np.random.RandomState(epoch)
     shift_len = rng.randint(0, batch_size - 1)
     batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
     rng.shuffle(batch_indices)
     batch_indices = [item for batch in batch_indices for item in batch]
     assert clipped is False
deepspeech/models/ds2/deepspeech2.py (+33 −0)

@@ -228,6 +228,27 @@ class DeepSpeech2Model(nn.Layer):
         layer_tools.summary(model)
         return model

+    @classmethod
+    def from_config(cls, config):
+        """Build a DeepSpeech2Model from config
+        Parameters
+
+        config: yacs.config.CfgNode
+            config.model
+        Returns
+        -------
+        DeepSpeech2Model
+            The model built from config.
+        """
+        model = cls(feat_size=config.feat_size,
+                    dict_size=config.dict_size,
+                    num_conv_layers=config.num_conv_layers,
+                    num_rnn_layers=config.num_rnn_layers,
+                    rnn_size=config.rnn_layer_size,
+                    use_gru=config.use_gru,
+                    share_rnn_weights=config.share_rnn_weights)
+        return model

 class DeepSpeech2InferModel(DeepSpeech2Model):
     def __init__(self,
@@ -260,3 +281,15 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
         eouts, eouts_len = self.encoder(audio, audio_len)
         probs = self.decoder.softmax(eouts)
         return probs
+
+    def export(self):
+        static_model = paddle.jit.to_static(
+            self,
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=[None, None, self.encoder.feat_size],
+                    dtype='float32'),  # audio, [B,T,D]
+                paddle.static.InputSpec(shape=[None], dtype='int64'),  # audio_length, [B]
+            ])
+        return static_model
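With from_config and export in place, the export flow in model.py above reduces to a handful of calls. A minimal sketch, assuming test_loader and config come from the surrounding trainer context (the checkpoint prefix here is illustrative, not from the diff):

import paddle
from deepspeech.models.ds2 import DeepSpeech2InferModel

# Restore weights into an inference-shaped model, then trace it to a static graph.
infer_model = DeepSpeech2InferModel.from_pretrained(
    test_loader, config, 'exp/ckpt/checkpoints/avg_1')  # hypothetical checkpoint prefix
infer_model.eval()
static_model = infer_model.export()  # wraps paddle.jit.to_static with the audio input_spec
paddle.jit.save(static_model, 'exp/ckpt/checkpoints/avg_1.jit')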
deepspeech/models/ds2_online/__init__.py (new file, +17)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModelOnline
from .deepspeech2 import DeepSpeech2ModelOnline

__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
deepspeech/models/ds2_online/conv.py (new file, +35)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn

from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.subsampling import Conv2dSubsampling4


class Conv2dSubsampling4Online(Conv2dSubsampling4):
    def __init__(self, idim: int, odim: int, dropout_rate: float):
        super().__init__(idim, odim, dropout_rate, None)
        self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim
        self.receptive_field_length = 2 * (
            3 - 1) + 3  # stride_1 * (kernel_size_2 - 1) + kernel_size_1

    def forward(self, x: paddle.Tensor,
                x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]:
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        #b, c, t, f = paddle.shape(x) #not work under jit
        x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1])
        x_len = ((x_len - 1) // 2 - 1) // 2
        return x, x_len
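The two stride-2 convolutions shrink both the time and feature axes, and the output_dim and x_len formulas above can be checked with plain integer arithmetic. A small self-contained sketch (161 and 210 mirror the linear-spectrogram feature size and the test utterance length used elsewhere in this diff):

def subsampled_len(n: int) -> int:
    # Two conv layers, each kernel 3 and stride 2: ((n - 1) // 2 - 1) // 2
    return ((n - 1) // 2 - 1) // 2

idim, odim = 161, 32
print(subsampled_len(idim) * odim)  # flattened feature dim per frame: 39 * 32 = 1248
print(subsampled_len(210))          # 210 input frames -> 51 subsampled frames
print(2 * (3 - 1) + 3)              # receptive field length: 7 input frames per output frame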
deepspeech/models/ds2_online/deepspeech2.py (new file, +427)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Online Model"""
from typing import Optional

import paddle
import paddle.nn.functional as F
from paddle import nn
from yacs.config import CfgNode

from deepspeech.models.ds2_online.conv import Conv2dSubsampling4Online
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']


class CRNNEncoder(nn.Layer):
    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=4,
                 rnn_size=1024,
                 rnn_direction='forward',
                 num_fc_layers=2,
                 fc_layers_size_list=[512, 256],
                 use_gru=False):
        super().__init__()
        self.rnn_size = rnn_size
        self.feat_size = feat_size  # 161 for linear
        self.dict_size = dict_size
        self.num_rnn_layers = num_rnn_layers
        self.num_fc_layers = num_fc_layers
        self.rnn_direction = rnn_direction
        self.fc_layers_size_list = fc_layers_size_list
        self.use_gru = use_gru
        self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0)

        self.output_dim = self.conv.output_dim

        i_size = self.conv.output_dim
        self.rnn = nn.LayerList()
        self.layernorm_list = nn.LayerList()
        self.fc_layers_list = nn.LayerList()
        if rnn_direction == 'bidirect' or rnn_direction == 'bidirectional':
            layernorm_size = 2 * rnn_size
        elif rnn_direction == 'forward':
            layernorm_size = rnn_size
        else:
            raise Exception("Wrong rnn direction")
        for i in range(0, num_rnn_layers):
            if i == 0:
                rnn_input_size = i_size
            else:
                rnn_input_size = layernorm_size
            if use_gru == True:
                self.rnn.append(
                    nn.GRU(
                        input_size=rnn_input_size,
                        hidden_size=rnn_size,
                        num_layers=1,
                        direction=rnn_direction))
            else:
                self.rnn.append(
                    nn.LSTM(
                        input_size=rnn_input_size,
                        hidden_size=rnn_size,
                        num_layers=1,
                        direction=rnn_direction))
            self.layernorm_list.append(nn.LayerNorm(layernorm_size))
            self.output_dim = layernorm_size

        fc_input_size = layernorm_size
        for i in range(self.num_fc_layers):
            self.fc_layers_list.append(
                nn.Linear(fc_input_size, fc_layers_size_list[i]))
            fc_input_size = fc_layers_size_list[i]
            self.output_dim = fc_layers_size_list[i]

    @property
    def output_size(self):
        return self.output_dim

    def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None):
        """Compute Encoder outputs

        Args:
            x (Tensor): [B, feature_size, D]
            x_lens (Tensor): [B]
            init_state_h_box (Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
            init_state_c_box (Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
        Returns:
            x (Tensor): encoder outputs, [B, size, D]
            x_lens (Tensor): encoder length, [B]
            final_state_h_box (Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
            final_state_c_box (Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
        """
        if init_state_h_box is not None:
            init_state_list = None
            if self.use_gru == True:
                init_state_h_list = paddle.split(
                    init_state_h_box, self.num_rnn_layers, axis=0)
                init_state_list = init_state_h_list
            else:
                init_state_h_list = paddle.split(
                    init_state_h_box, self.num_rnn_layers, axis=0)
                init_state_c_list = paddle.split(
                    init_state_c_box, self.num_rnn_layers, axis=0)
                init_state_list = [(init_state_h_list[i], init_state_c_list[i])
                                   for i in range(self.num_rnn_layers)]
        else:
            init_state_list = [None] * self.num_rnn_layers

        x, x_lens = self.conv(x, x_lens)
        final_chunk_state_list = []
        for i in range(0, self.num_rnn_layers):
            x, final_state = self.rnn[i](x, init_state_list[i],
                                         x_lens)  #[B, T, D]
            final_chunk_state_list.append(final_state)
            x = self.layernorm_list[i](x)

        for i in range(self.num_fc_layers):
            x = self.fc_layers_list[i](x)
            x = F.relu(x)

        if self.use_gru == True:
            final_chunk_state_h_box = paddle.concat(
                final_chunk_state_list, axis=0)
            final_chunk_state_c_box = init_state_c_box  #paddle.zeros_like(final_chunk_state_h_box)
        else:
            final_chunk_state_h_list = [
                final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
            ]
            final_chunk_state_c_list = [
                final_chunk_state_list[i][1] for i in range(self.num_rnn_layers)
            ]
            final_chunk_state_h_box = paddle.concat(
                final_chunk_state_h_list, axis=0)
            final_chunk_state_c_box = paddle.concat(
                final_chunk_state_c_list, axis=0)

        return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box

    def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8):
        """Compute Encoder outputs

        Args:
            x (Tensor): [B, T, D]
            x_lens (Tensor): [B]
            decoder_chunk_size: The chunk size of decoder
        Returns:
            eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks
            eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks
            final_state_h_box (Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
            final_state_c_box (Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
        """
        subsampling_rate = self.conv.subsampling_rate
        receptive_field_length = self.conv.receptive_field_length
        chunk_size = (decoder_chunk_size - 1
                      ) * subsampling_rate + receptive_field_length
        chunk_stride = subsampling_rate * decoder_chunk_size
        max_len = x.shape[1]
        assert (chunk_size <= max_len)

        eouts_chunk_list = []
        eouts_chunk_lens_list = []

        padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
        padding = paddle.zeros((x.shape[0], padding_len, x.shape[2]))
        padded_x = paddle.concat([x, padding], axis=1)
        num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1
        num_chunk = int(num_chunk)
        chunk_state_h_box = None
        chunk_state_c_box = None
        final_state_h_box = None
        final_state_c_box = None
        for i in range(0, num_chunk):
            start = i * chunk_stride
            end = start + chunk_size
            x_chunk = padded_x[:, start:end, :]

            x_len_left = paddle.where(x_lens - i * chunk_stride < 0,
                                      paddle.zeros_like(x_lens),
                                      x_lens - i * chunk_stride)
            x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size
            x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp,
                                        x_len_left, x_chunk_len_tmp)

            eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward(
                x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box)

            eouts_chunk_list.append(eouts_chunk)
            eouts_chunk_lens_list.append(eouts_chunk_lens)
        final_state_h_box = chunk_state_h_box
        final_state_c_box = chunk_state_c_box
        return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box


class DeepSpeech2ModelOnline(nn.Layer):
    """The DeepSpeech2 network structure for online.

    :param audio_data: Audio spectrogram data layer.
    :type audio_data: Variable
    :param text_data: Transcription text data layer.
    :type text_data: Variable
    :param audio_len: Valid sequence length data layer.
    :type audio_len: Variable
    :param dict_size: Dictionary size for tokenized transcription.
    :type dict_size: int
    :param num_conv_layers: Number of stacking convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacking RNN layers.
    :type num_rnn_layers: int
    :param rnn_size: RNN layer size (dimension of RNN cells).
    :type rnn_size: int
    :param num_fc_layers: Number of stacking FC layers.
    :type num_fc_layers: int
    :param fc_layers_size_list: The list of FC layer sizes.
    :type fc_layers_size_list: [int,]
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :return: A tuple of an output unnormalized log probability layer (
             before softmax) and a ctc cost layer.
    :rtype: tuple of LayerOutput
    """

    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                num_conv_layers=2,  #Number of stacking convolution layers.
                num_rnn_layers=4,  #Number of stacking RNN layers.
                rnn_layer_size=1024,  #RNN layer size (number of RNN cells).
                num_fc_layers=2,
                fc_layers_size_list=[512, 256],
                use_gru=True,  #Use gru if set True. Use simple rnn if set False.
            ))
        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=4,
                 rnn_size=1024,
                 rnn_direction='forward',
                 num_fc_layers=2,
                 fc_layers_size_list=[512, 256],
                 use_gru=False):
        super().__init__()
        self.encoder = CRNNEncoder(
            feat_size=feat_size,
            dict_size=dict_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_direction=rnn_direction,
            num_fc_layers=num_fc_layers,
            fc_layers_size_list=fc_layers_size_list,
            rnn_size=rnn_size,
            use_gru=use_gru)

        self.decoder = CTCDecoder(
            odim=dict_size,  # <blank> is in vocab
            enc_n_units=self.encoder.output_size,
            blank_id=0,  # first token is <blank>
            dropout_rate=0.0,
            reduction=True,  # sum
            batch_average=True)  # sum / batch_size

    def forward(self, audio, audio_len, text, text_len):
        """Compute Model loss

        Args:
            audio (Tensor): [B, T, D]
            audio_len (Tensor): [B]
            text (Tensor): [B, U]
            text_len (Tensor): [B]

        Returns:
            loss (Tensor): [1]
        """
        eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
            audio, audio_len, None, None)
        loss = self.decoder(eouts, eouts_len, text, text_len)
        return loss

    @paddle.no_grad()
    def decode(self, audio, audio_len, vocab_list, decoding_method,
               lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
               cutoff_top_n, num_processes):
        # init once
        # decoders only accept string encoded in utf-8
        self.decoder.init_decode(
            beam_alpha=beam_alpha,
            beam_beta=beam_beta,
            lang_model_path=lang_model_path,
            vocab_list=vocab_list,
            decoding_method=decoding_method)

        eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
            audio, audio_len, None, None)
        probs = self.decoder.softmax(eouts)
        return self.decoder.decode_probs(
            probs.numpy(), eouts_len, vocab_list, decoding_method,
            lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
            cutoff_top_n, num_processes)

    @classmethod
    def from_pretrained(cls, dataloader, config, checkpoint_path):
        """Build a DeepSpeech2ModelOnline from a pretrained model.

        Parameters
        ----------
        dataloader: paddle.io.DataLoader
        config: yacs.config.CfgNode
            model configs
        checkpoint_path: Path or str
            the path of pretrained model checkpoint, without extension name

        Returns
        -------
        DeepSpeech2ModelOnline
            The model built from pretrained result.
        """
        model = cls(feat_size=dataloader.collate_fn.feature_size,
                    dict_size=dataloader.collate_fn.vocab_size,
                    num_conv_layers=config.model.num_conv_layers,
                    num_rnn_layers=config.model.num_rnn_layers,
                    rnn_size=config.model.rnn_layer_size,
                    rnn_direction=config.model.rnn_direction,
                    num_fc_layers=config.model.num_fc_layers,
                    fc_layers_size_list=config.model.fc_layers_size_list,
                    use_gru=config.model.use_gru)
        infos = Checkpoint().load_parameters(
            model, checkpoint_path=checkpoint_path)
        logger.info(f"checkpoint info: {infos}")
        layer_tools.summary(model)
        return model

    @classmethod
    def from_config(cls, config):
        """Build a DeepSpeech2ModelOnline from config

        Parameters

        config: yacs.config.CfgNode
            config.model
        Returns
        -------
        DeepSpeech2ModelOnline
            The model built from config.
        """
        model = cls(feat_size=config.feat_size,
                    dict_size=config.dict_size,
                    num_conv_layers=config.num_conv_layers,
                    num_rnn_layers=config.num_rnn_layers,
                    rnn_size=config.rnn_layer_size,
                    rnn_direction=config.rnn_direction,
                    num_fc_layers=config.num_fc_layers,
                    fc_layers_size_list=config.fc_layers_size_list,
                    use_gru=config.use_gru)
        return model


class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=4,
                 rnn_size=1024,
                 rnn_direction='forward',
                 num_fc_layers=2,
                 fc_layers_size_list=[512, 256],
                 use_gru=False):
        super().__init__(
            feat_size=feat_size,
            dict_size=dict_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_size=rnn_size,
            rnn_direction=rnn_direction,
            num_fc_layers=num_fc_layers,
            fc_layers_size_list=fc_layers_size_list,
            use_gru=use_gru)

    def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
                chunk_state_c_box):
        eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder(
            audio_chunk, audio_chunk_lens, chunk_state_h_box,
            chunk_state_c_box)
        probs_chunk = self.decoder.softmax(eouts_chunk)
        return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box

    def export(self):
        static_model = paddle.jit.to_static(
            self,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None, None, self.encoder.feat_size],  #[B, chunk_size, feat_dim]
                    dtype='float32'),
                paddle.static.InputSpec(shape=[None], dtype='int64'),  # audio_length, [B]
                paddle.static.InputSpec(shape=[None, None, None], dtype='float32'),
                paddle.static.InputSpec(shape=[None, None, None], dtype='float32')
            ])
        return static_model
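forward_chunk_by_chunk slices the padded input so every chunk yields exactly decoder_chunk_size subsampled frames while overlapping by the convolutional receptive field. The arithmetic can be traced by hand; a minimal sketch, assuming the inherited self.conv.subsampling_rate is 4 (an assumption about Conv2dSubsampling4, not shown in this diff):

decoder_chunk_size = 8
subsampling_rate = 4                        # assumed from Conv2dSubsampling4
receptive_field_length = 2 * (3 - 1) + 3    # 7, set in Conv2dSubsampling4Online

chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length
chunk_stride = subsampling_rate * decoder_chunk_size
print(chunk_size, chunk_stride)  # 35 32: 35-frame windows advanced by 32 frames

max_len = 210  # input frames, as in the unit test below
padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
num_chunk = (max_len + padding_len - chunk_size) // chunk_stride + 1
print(padding_len, num_chunk)    # 17 7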
deepspeech/modules/subsampling.py (+3 −3)

@@ -92,7 +92,7 @@ class Conv2dSubsampling4(BaseSubsampling):
                  dropout_rate: float,
                  pos_enc_class: nn.Layer=PositionalEncoding):
         """Construct an Conv2dSubsampling4 object.

         Args:
             idim (int): Input dimension.
             odim (int): Output dimension.
@@ -143,7 +143,7 @@ class Conv2dSubsampling6(BaseSubsampling):
                  dropout_rate: float,
                  pos_enc_class: nn.Layer=PositionalEncoding):
         """Construct an Conv2dSubsampling6 object.

         Args:
             idim (int): Input dimension.
             odim (int): Output dimension.
@@ -196,7 +196,7 @@ class Conv2dSubsampling8(BaseSubsampling):
                  dropout_rate: float,
                  pos_enc_class: nn.Layer=PositionalEncoding):
         """Construct an Conv2dSubsampling8 object.

         Args:
             idim (int): Input dimension.
             odim (int): Output dimension.
env.sh (+1 −1)

@@ -4,7 +4,7 @@ export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
 export LC_ALL=C

 # Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
 export PYTHONIOENCODING=UTF-8
 export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

 export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
examples/aishell/s0/conf/deepspeech2_online.yaml (new file, +67)

# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.train
  dev_manifest: data/manifest.dev
  test_manifest: data/manifest.test
  min_input_len: 0.0
  max_input_len: 27.0  # second
  min_output_len: 0.0
  max_output_len: .inf
  min_output_input_ratio: 0.00
  max_output_input_ratio: .inf

collator:
  batch_size: 32  # one gpu
  mean_std_filepath: data/mean_std.json
  unit_type: char
  vocab_filepath: data/vocab.txt
  augmentation_config: conf/augmentation.json
  random_seed: 0
  spm_model_prefix:
  specgram_type: linear  #linear, mfcc, fbank
  feat_dim:
  delta_delta: False
  stride_ms: 10.0
  window_ms: 20.0
  n_fft: None
  max_freq: None
  target_sample_rate: 16000
  use_dB_normalization: True
  target_dB: -20
  dither: 1.0
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 0

model:
  num_conv_layers: 2
  num_rnn_layers: 3
  rnn_layer_size: 1024
  rnn_direction: forward  # [forward, bidirect]
  num_fc_layers: 1
  fc_layers_size_list: 512,
  use_gru: True

training:
  n_epoch: 50
  lr: 2e-3
  lr_decay: 0.83  # 0.83
  weight_decay: 1e-06
  global_grad_clip: 3.0
  log_interval: 100
  checkpoint:
    kbest_n: 50
    latest_n: 5

decoding:
  batch_size: 32
  error_rate_type: cer
  decoding_method: ctc_beam_search
  lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
  alpha: 1.9
  beta: 5.0
  beam_size: 300
  cutoff_prob: 0.99
  cutoff_top_n: 40
  num_proc_bsearch: 10
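Because run.sh (see its diff below) sources ${MAIN_ROOT}/utils/parse_options.sh, conf_path, avg_num and the new model_type variable can be overridden from the command line; a typical invocation for this config would presumably be `bash run.sh --conf_path conf/deepspeech2_online.yaml --model_type online`.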
examples/aishell/s0/local/export.sh (+5 −4)

 #!/bin/bash

-if [ $# != 3 ];then
-    echo "usage: $0 config_path ckpt_prefix jit_model_path"
+if [ $# != 4 ];then
+    echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
     exit -1
 fi
@@ -11,6 +11,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_path_prefix=$2
 jit_model_export_path=$3
+model_type=$4

 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
 --nproc ${ngpu} \
 --config ${config_path} \
 --checkpoint_path ${ckpt_path_prefix} \
---export_path ${jit_model_export_path}
+--export_path ${jit_model_export_path} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in export!"
examples/aishell/s0/local/test.sh (+5 −3)

 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type"
     exit -1
 fi
@@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then
 fi
 config_path=$1
 ckpt_prefix=$2
+model_type=$3

 # download language model
 bash local/download_lm_ch.sh
@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
 --nproc 1 \
 --config ${config_path} \
 --result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix}
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in evaluation!"
examples/aishell/s0/local/train.sh (+5 −3)

 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# != 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
     exit -1
 fi
@@ -10,6 +10,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+model_type=$3

 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \
 --device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \
---output exp/${ckpt_name}
+--output exp/${ckpt_name} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in training!"
examples/aishell/s0/run.sh (+4 −3)

@@ -7,6 +7,7 @@ stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
+model_type=offline

 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -21,7 +22,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -31,10 +32,10 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # export ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
+    CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
 fi
examples/librispeech/s0/conf/deepspeech2_online.yaml (new file, +67)

# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.train
  dev_manifest: data/manifest.dev-clean
  test_manifest: data/manifest.test-clean
  min_input_len: 0.0
  max_input_len: 27.0  # second
  min_output_len: 0.0
  max_output_len: .inf
  min_output_input_ratio: 0.00
  max_output_input_ratio: .inf

collator:
  batch_size: 20
  mean_std_filepath: data/mean_std.json
  unit_type: char
  vocab_filepath: data/vocab.txt
  augmentation_config: conf/augmentation.json
  random_seed: 0
  spm_model_prefix:
  specgram_type: linear
  target_sample_rate: 16000
  max_freq: None
  n_fft: None
  stride_ms: 10.0
  window_ms: 20.0
  delta_delta: False
  dither: 1.0
  use_dB_normalization: True
  target_dB: -20
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 0

model:
  num_conv_layers: 2
  num_rnn_layers: 3
  rnn_layer_size: 2048
  rnn_direction: forward
  num_fc_layers: 2
  fc_layers_size_list: 512, 256
  use_gru: False

training:
  n_epoch: 50
  lr: 1e-3
  lr_decay: 0.83
  weight_decay: 1e-06
  global_grad_clip: 5.0
  log_interval: 100
  checkpoint:
    kbest_n: 50
    latest_n: 5

decoding:
  batch_size: 128
  error_rate_type: wer
  decoding_method: ctc_beam_search
  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
  alpha: 1.9
  beta: 0.3
  beam_size: 500
  cutoff_prob: 1.0
  cutoff_top_n: 40
  num_proc_bsearch: 8
examples/librispeech/s0/local/export.sh (+5 −4)

 #!/bin/bash

-if [ $# != 3 ];then
-    echo "usage: $0 config_path ckpt_prefix jit_model_path"
+if [ $# != 4 ];then
+    echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
     exit -1
 fi
@@ -11,6 +11,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_path_prefix=$2
 jit_model_export_path=$3
+model_type=$4

 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
 --nproc ${ngpu} \
 --config ${config_path} \
 --checkpoint_path ${ckpt_path_prefix} \
---export_path ${jit_model_export_path}
+--export_path ${jit_model_export_path} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in export!"
examples/librispeech/s0/local/test.sh (+5 −3)

 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type"
     exit -1
 fi
@@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then
 fi
 config_path=$1
 ckpt_prefix=$2
+model_type=$3

 # download language model
 bash local/download_lm_en.sh
@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
 --nproc 1 \
 --config ${config_path} \
 --result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix}
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in evaluation!"
examples/librispeech/s0/local/train.sh (+5 −3)

 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# != 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
     exit -1
 fi
@@ -10,6 +10,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+model_type=$3

 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -23,7 +24,8 @@ python3 -u ${BIN_DIR}/train.py \
 --device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \
---output exp/${ckpt_name}
+--output exp/${ckpt_name} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in training!"
examples/librispeech/s0/run.sh (+4 −3)

@@ -6,6 +6,7 @@ stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=30
+model_type=offline

 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

 avg_ckpt=avg_${avg_num}
@@ -19,7 +20,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} ${model_type}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -29,10 +30,10 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # export ckpt avg_n
-    CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
+    CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
 fi
examples/tiny/s0/conf/deepspeech2_online.yaml (new file, +69)

# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.tiny
  dev_manifest: data/manifest.tiny
  test_manifest: data/manifest.tiny
  min_input_len: 0.0
  max_input_len: 27.0
  min_output_len: 0.0
  max_output_len: 400.0
  min_output_input_ratio: 0.05
  max_output_input_ratio: 10.0

collator:
  mean_std_filepath: data/mean_std.json
  unit_type: char
  vocab_filepath: data/vocab.txt
  augmentation_config: conf/augmentation.json
  random_seed: 0
  spm_model_prefix:
  specgram_type: linear
  feat_dim:
  delta_delta: False
  stride_ms: 10.0
  window_ms: 20.0
  n_fft: None
  max_freq: None
  target_sample_rate: 16000
  use_dB_normalization: True
  target_dB: -20
  dither: 1.0
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 0
  batch_size: 4

model:
  num_conv_layers: 2
  num_rnn_layers: 4
  rnn_layer_size: 2048
  rnn_direction: forward
  num_fc_layers: 2
  fc_layers_size_list: 512, 256
  use_gru: True

training:
  n_epoch: 10
  lr: 1e-5
  lr_decay: 1.0
  weight_decay: 1e-06
  global_grad_clip: 5.0
  log_interval: 1
  checkpoint:
    kbest_n: 3
    latest_n: 2

decoding:
  batch_size: 128
  error_rate_type: wer
  decoding_method: ctc_beam_search
  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
  alpha: 2.5
  beta: 0.3
  beam_size: 500
  cutoff_prob: 1.0
  cutoff_top_n: 40
  num_proc_bsearch: 8
examples/tiny/s0/local/export.sh (+5 −4)

 #!/bin/bash

-if [ $# != 3 ];then
-    echo "usage: $0 config_path ckpt_prefix jit_model_path"
+if [ $# != 4 ];then
+    echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
     exit -1
 fi
@@ -11,6 +11,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_path_prefix=$2
 jit_model_export_path=$3
+model_type=$4

 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
 --nproc ${ngpu} \
 --config ${config_path} \
 --checkpoint_path ${ckpt_path_prefix} \
---export_path ${jit_model_export_path}
+--export_path ${jit_model_export_path} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in export!"
examples/tiny/s0/local/test.sh (+5 −3)

 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type"
     exit -1
 fi
@@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then
 fi
 config_path=$1
 ckpt_prefix=$2
+model_type=$3

 # download language model
 bash local/download_lm_en.sh
@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
 --nproc 1 \
 --config ${config_path} \
 --result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix}
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in evaluation!"
examples/tiny/s0/local/train.sh (+5 −3)

 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# != 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
     exit -1
 fi
@@ -10,6 +10,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+model_type=$3

 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \
 --device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \
---output exp/${ckpt_name}
+--output exp/${ckpt_name} \
+--model_type ${model_type}

 if [ $? -ne 0 ]; then
     echo "Failed in training!"
examples/tiny/s0/run.sh (+4 −3)

@@ -7,6 +7,7 @@ stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
+model_type=offline

 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -21,7 +22,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -31,10 +32,10 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # export ckpt avg_n
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
 fi
tests/deepspeech2_model_test.py (+1 −1)

@@ -16,7 +16,7 @@ import unittest
 import numpy as np
 import paddle

-from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.ds2 import DeepSpeech2Model


 class TestDeepSpeech2Model(unittest.TestCase):
tests/deepspeech2_online_model_test.py (new file, +186)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
import paddle

from deepspeech.models.ds2_online import DeepSpeech2ModelOnline


class TestDeepSpeech2ModelOnline(unittest.TestCase):
    def setUp(self):
        paddle.set_device('cpu')

        self.batch_size = 2
        self.feat_dim = 161
        max_len = 210

        # (B, T, D)
        audio = np.random.randn(self.batch_size, max_len, self.feat_dim)
        audio_len = np.random.randint(max_len, size=self.batch_size)
        audio_len[-1] = max_len
        # (B, U)
        text = np.array([[1, 2], [1, 2]])
        text_len = np.array([2] * self.batch_size)

        self.audio = paddle.to_tensor(audio, dtype='float32')
        self.audio_len = paddle.to_tensor(audio_len, dtype='int64')
        self.text = paddle.to_tensor(text, dtype='int32')
        self.text_len = paddle.to_tensor(text_len, dtype='int64')

    def test_ds2_1(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=False)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_2(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=True)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_3(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=False)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_4(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=True)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_5(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=False)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_6(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            rnn_direction='bidirect',
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=False)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_7(self):
        use_gru = False
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=1,
            rnn_size=1024,
            rnn_direction='forward',
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=use_gru)
        model.eval()
        paddle.device.set_device("cpu")
        de_ch_size = 8

        eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
            self.audio, self.audio_len)
        eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
            self.audio, self.audio_len, de_ch_size)
        eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
        eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
        decode_max_len = eouts.shape[1]
        eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
        self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
        self.assertEqual(
            paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
        if use_gru == False:
            self.assertEqual(
                paddle.allclose(final_state_c_box, final_state_c_box_chk),
                True)

    def test_ds2_8(self):
        use_gru = True
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=1,
            rnn_size=1024,
            rnn_direction='forward',
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=use_gru)
        model.eval()
        paddle.device.set_device("cpu")
        de_ch_size = 8

        eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
            self.audio, self.audio_len)
        eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
            self.audio, self.audio_len, de_ch_size)
        eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
        eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
        decode_max_len = eouts.shape[1]
        eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
        self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
        self.assertEqual(
            paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
        if use_gru == False:
            self.assertEqual(
                paddle.allclose(final_state_c_box, final_state_c_box_chk),
                True)


if __name__ == '__main__':
    unittest.main()
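The test module pins execution to CPU in setUp, so it can be run standalone (e.g. python3 tests/deepspeech2_online_model_test.py). Tests 1 through 6 only check that the CTC loss comes back as a scalar; tests 7 and 8 are the substantive ones for streaming: they assert that concatenating the chunk-by-chunk encoder outputs, truncated to the non-chunked output length, reproduces the whole-utterance encoder outputs and final RNN states.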