PaddlePaddle / DeepSpeech · Commit 8e73d184
Commit 8e73d184
Authored Oct 05, 2021 by Hui Zhang
Commit message: tiny/s0/s1 can run all
Parent: 6f83be1a
Showing 41 changed files with 1178 additions and 545 deletions (+1178 / -545).
deepspeech/exps/deepspeech2/bin/deploy/runtime.py (+16 / -9)
deepspeech/exps/deepspeech2/bin/deploy/server.py (+20 / -9)
deepspeech/exps/deepspeech2/model.py (+328 / -104)
deepspeech/exps/u2/bin/alignment.py (+3 / -0)
deepspeech/exps/u2/bin/export.py (+3 / -0)
deepspeech/exps/u2/bin/test.py (+3 / -0)
deepspeech/exps/u2/bin/train.py (+3 / -1)
deepspeech/exps/u2/model.py (+45 / -45)
deepspeech/frontend/featurizer/text_featurizer.py (+61 / -51)
deepspeech/frontend/utility.py (+122 / -37)
deepspeech/io/collator.py (+9 / -3)
deepspeech/io/dataset.py (+2 / -1)
deepspeech/models/ds2/conv.py (+7 / -7)
deepspeech/models/ds2/deepspeech2.py (+7 / -21)
deepspeech/models/ds2/rnn.py (+6 / -6)
deepspeech/models/ds2_online/deepspeech2.py (+41 / -55)
deepspeech/training/trainer.py (+116 / -60)
deepspeech/utils/log.py (+3 / -3)
examples/dataset/mini_librispeech/.gitignore (+1 / -0)
examples/dataset/mini_librispeech/mini_librispeech.py (+21 / -0)
examples/librispeech/s1/local/align.sh (+32 / -0)
examples/librispeech/s1/local/data.sh (+1 / -1)
examples/librispeech/s1/local/download_lm_en.sh (+1 / -1)
examples/librispeech/s1/local/export.sh (+1 / -7)
examples/librispeech/s1/local/test.sh (+67 / -39)
examples/librispeech/s1/local/train.sh (+17 / -8)
examples/tiny/s0/conf/deepspeech2.yaml (+3 / -0)
examples/tiny/s0/conf/deepspeech2_online.yaml (+72 / -0)
examples/tiny/s0/local/download_lm_en.sh (+6 / -1)
examples/tiny/s0/local/export.sh (+6 / -11)
examples/tiny/s0/local/test.sh (+7 / -10)
examples/tiny/s0/local/train.sh (+26 / -12)
examples/tiny/s0/path.sh (+1 / -1)
examples/tiny/s0/run.sh (+6 / -5)
examples/tiny/s1/conf/transformer.yaml (+2 / -0)
examples/tiny/s1/local/align.sh (+32 / -0)
examples/tiny/s1/local/data.sh (+1 / -1)
examples/tiny/s1/local/export.sh (+1 / -7)
examples/tiny/s1/local/test.sh (+42 / -15)
examples/tiny/s1/local/train.sh (+28 / -11)
examples/tiny/s1/run.sh (+9 / -3)
deepspeech/exps/deepspeech2/bin/deploy/runtime.py

@@ -18,8 +18,10 @@ import numpy as np
 import paddle
 from paddle.inference import Config
 from paddle.inference import create_predictor
+from paddle.io import DataLoader

 from deepspeech.exps.deepspeech2.config import get_cfg_defaults
+from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.models.ds2 import DeepSpeech2Model
 from deepspeech.training.cli import default_argument_parser

@@ -78,26 +80,31 @@ def inference(config, args):
 def start_server(config, args):
     """Start the ASR server"""
     config.defrost()
-    config.data.manfiest = config.data.test_manifest
-    config.data.augmentation_config = ""
-    config.data.keep_transcription_text = True
+    config.data.manifest = config.data.test_manifest
     dataset = ManifestDataset.from_config(config)
-    model = DeepSpeech2Model.from_pretrained(dataset, config,
+
+    config.collator.augmentation_config = ""
+    config.collator.keep_transcription_text = True
+    config.collator.batch_size = 1
+    config.collator.num_workers = 0
+    collate_fn = SpeechCollator.from_config(config)
+    test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
+
+    model = DeepSpeech2Model.from_pretrained(test_loader, config,
                                              args.checkpoint_path)
     model.eval()

     # prepare ASR inference handler
     def file_to_transcript(filename):
-        feature = dataset.process_utterance(filename, "")
-        audio = np.array([feature[0]]).astype('float32')  #[1, D, T]
-        audio_len = feature[0].shape[1]
+        feature = test_loader.collate_fn.process_utterance(filename, "")
+        audio = np.array([feature[0]]).astype('float32')  #[1, T, D]
+        audio_len = feature[0].shape[0]
         audio_len = np.array([audio_len]).astype('int64')  # [1]

         result_transcript = model.decode(
             paddle.to_tensor(audio),
             paddle.to_tensor(audio_len),
-            vocab_list=dataset.vocab_list,
+            vocab_list=test_loader.collate_fn.vocab_list,
             decoding_method=config.decoding.decoding_method,
             lang_model_path=config.decoding.lang_model_path,
             beam_alpha=config.decoding.alpha,

@@ -138,7 +145,7 @@ if __name__ == "__main__":
     add_arg('host_ip', str, 'localhost', "Server's IP address.")
-    add_arg('host_port', int, 8086, "Server's IP port.")
+    add_arg('host_port', int, 8089, "Server's IP port.")
     add_arg('speech_save_dir', str, 'demo_cache',
             "Directory to save demo audios.")
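The substance of this change is that featurization moves off the dataset and onto the collator, with a DataLoader in between. A minimal, untested sketch of the new wiring follows; it assumes the repo's deepspeech package is importable with a prepared config, and "demo.wav" is a made-up input file.

# Sketch of the data path introduced by this commit: the collator, not the
# dataset, now owns featurization and the vocabulary.
from paddle.io import DataLoader

from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset

config = get_cfg_defaults()
config.defrost()
config.data.manifest = config.data.test_manifest
dataset = ManifestDataset.from_config(config)

config.collator.batch_size = 1
config.collator.num_workers = 0
collate_fn = SpeechCollator.from_config(config)
test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)

# features and the vocab are read through the collator attached to the loader
feature = test_loader.collate_fn.process_utterance("demo.wav", "")  # hypothetical file
vocab_list = test_loader.collate_fn.vocab_list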
deepspeech/exps/deepspeech2/bin/deploy/server.py

@@ -16,8 +16,10 @@ import functools
 import numpy as np
 import paddle
+from paddle.io import DataLoader

 from deepspeech.exps.deepspeech2.config import get_cfg_defaults
+from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.models.ds2 import DeepSpeech2Model
 from deepspeech.training.cli import default_argument_parser

@@ -31,26 +33,35 @@ from deepspeech.utils.utility import print_arguments
 def start_server(config, args):
     """Start the ASR server"""
     config.defrost()
-    config.data.manfiest = config.data.test_manifest
-    config.data.augmentation_config = ""
-    config.data.keep_transcription_text = True
+    config.data.manifest = config.data.test_manifest
     dataset = ManifestDataset.from_config(config)
-    model = DeepSpeech2Model.from_pretrained(dataset, config,
+
+    config.collator.augmentation_config = ""
+    config.collator.keep_transcription_text = True
+    config.collator.batch_size = 1
+    config.collator.num_workers = 0
+    collate_fn = SpeechCollator.from_config(config)
+    test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
+
+    model = DeepSpeech2Model.from_pretrained(test_loader, config,
                                              args.checkpoint_path)
     model.eval()

     # prepare ASR inference handler
     def file_to_transcript(filename):
-        feature = dataset.process_utterance(filename, "")
-        audio = np.array([feature[0]]).astype('float32')  #[1, D, T]
-        audio_len = feature[0].shape[1]
+        feature = test_loader.collate_fn.process_utterance(filename, "")
+        audio = np.array([feature[0]]).astype('float32')  #[1, T, D]
+        # audio = audio.swapaxes(1,2)
+        print('---file_to_transcript feature----')
+        print(audio.shape)
+        audio_len = feature[0].shape[0]
+        print(audio_len)
         audio_len = np.array([audio_len]).astype('int64')  # [1]

         result_transcript = model.decode(
             paddle.to_tensor(audio),
             paddle.to_tensor(audio_len),
-            vocab_list=dataset.vocab_list,
+            vocab_list=test_loader.collate_fn.vocab_list,
             decoding_method=config.decoding.decoding_method,
             lang_model_path=config.decoding.lang_model_path,
             beam_alpha=config.decoding.alpha,

@@ -91,7 +102,7 @@ if __name__ == "__main__":
     add_arg('host_ip', str, 'localhost', "Server's IP address.")
-    add_arg('host_port', int, 8086, "Server's IP port.")
+    add_arg('host_port', int, 8088, "Server's IP port.")
     add_arg('speech_save_dir', str, 'demo_cache',
             "Directory to save demo audios.")
deepspeech/exps/deepspeech2/model.py (+328 / -104)

This diff is collapsed in the capture.
deepspeech/exps/u2/bin/alignment.py

@@ -30,6 +30,9 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # save asr result to
+    parser.add_argument(
+        "--result_file", type=str, help="path of save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())
deepspeech/exps/u2/bin/export.py

@@ -30,6 +30,9 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # save jit model to
+    parser.add_argument(
+        "--export_path", type=str, help="path of the jit model to save")
     args = parser.parse_args()
     print_arguments(args, globals())
deepspeech/exps/u2/bin/test.py

@@ -34,6 +34,9 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # save asr result to
+    parser.add_argument(
+        "--result_file", type=str, help="path of save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())
deepspeech/exps/u2/bin/train.py

@@ -22,6 +22,8 @@ from deepspeech.exps.u2.model import U2Trainer as Trainer
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils.utility import print_arguments

+# from deepspeech.exps.u2.trainer import U2Trainer as Trainer
+

 def main_sp(config, args):
     exp = Trainer(config, args)

@@ -30,7 +32,7 @@ def main_sp(config, args):
 def main(config, args):
-    if args.device == "gpu" and args.nprocs > 1:
+    if args.nprocs > 0:
         dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
     else:
         main_sp(config, args)
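The launch condition now keys on nprocs alone instead of device plus process count, so zero processes means a plain single-process (CPU) run. A toy rehearsal of that rule, with dist.spawn stubbed out:

# Stubbed illustration of the new launch rule; `spawn` stands in for
# paddle.distributed.spawn.
def launch(nprocs, spawn, run):
    if nprocs > 0:
        spawn(run, nprocs=nprocs)  # multi-process workers
    else:
        run()                      # single process, e.g. CPU

launch(0, spawn=None, run=lambda: print("single process"))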
deepspeech/exps/u2/model.py

@@ -73,11 +73,11 @@ class U2Trainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)

-    def train_batch(self, batch_index, batch_data, msg):
+    def train_batch(self, batch_index, batch, msg):
         train_conf = self.config.training
         start = time.time()

-        loss, attention_loss, ctc_loss = self.model(*batch_data)
+        loss, attention_loss, ctc_loss = self.model(*batch)
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
         loss.backward()

@@ -219,7 +219,7 @@ class U2Trainer(Trainer):
             config.data.augmentation_config = ""
             dev_dataset = ManifestDataset.from_config(config)

-        collate_fn = SpeechCollator(keep_transcription_text=False)
+        collate_fn = SpeechCollator(keep_transcription_text=False, return_utts=False)
         if self.parallel:
             batch_sampler = SortagradDistributedBatchSampler(
                 train_dataset,

@@ -269,7 +269,7 @@ class U2Trainer(Trainer):
             batch_size=config.decoding.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=SpeechCollator(keep_transcription_text=True))
+            collate_fn=SpeechCollator(keep_transcription_text=True, return_utts=True))
         logger.info("Setup train/valid/test Dataloader!")

     def setup_model(self):

@@ -345,7 +345,7 @@ class U2Tester(U2Trainer):
             decoding_chunk_size=-1,  # decoding chunk size. Defaults to -1.
             # <0: for decoding, use full chunk.
             # >0: for decoding, use fixed chunk size as set.
-            # 0: used for training, it's prohibited here.
+            # 0: used for training, it's prohibited here.
             num_decoding_left_chunks=-1,  # number of left chunks for decoding. Defaults to -1.
             simulate_streaming=False,  # simulate streaming inference. Defaults to False.
         ))
(trailing-whitespace change only)

@@ -428,7 +428,7 @@ class U2Tester(U2Trainer):
         num_time = 0.0
         with open(self.args.result_file, 'w') as fout:
             for i, batch in enumerate(self.test_loader):
-                metrics = self.compute_metrics(*batch, fout=fout)
+                metrics = self.compute_metrics(*batch[:-1], fout=fout)
                 num_frames += metrics['num_frames']
                 num_time += metrics["decode_time"]
                 errors_sum += metrics['errors_sum']

@@ -476,12 +476,12 @@ class U2Tester(U2Trainer):
             })
             f.write(data + '\n')

-    def run_test(self):
-        self.resume_or_scratch()
-        try:
-            self.test()
-        except KeyboardInterrupt:
-            sys.exit(-1)
+    # def run_test(self):
+    #     self.resume_or_scratch()
+    #     try:
+    #         self.test()
+    #     except KeyboardInterrupt:
+    #         sys.exit(-1)

     def load_inferspec(self):
         """infer model and input spec.

@@ -512,36 +512,36 @@ class U2Tester(U2Trainer):
         logger.info(f"Export code: {static_model.forward.code}")
         paddle.jit.save(static_model, self.args.export_path)

-    def run_export(self):
-        try:
-            self.export()
-        except KeyboardInterrupt:
-            sys.exit(-1)
-
-    def setup(self):
-        """Setup the experiment.
-        """
-        paddle.set_device(self.args.device)
-
-        self.setup_output_dir()
-        self.setup_checkpointer()
-
-        self.setup_dataloader()
-        self.setup_model()
-
-        self.iteration = 0
-        self.epoch = 0
-
-    def setup_output_dir(self):
-        """Create a directory used for output.
-        """
-        # output dir
-        if self.args.output:
-            output_dir = Path(self.args.output).expanduser()
-            output_dir.mkdir(parents=True, exist_ok=True)
-        else:
-            output_dir = Path(
-                self.args.checkpoint_path).expanduser().parent.parent
-            output_dir.mkdir(parents=True, exist_ok=True)
-        self.output_dir = output_dir
+    # def run_export(self):
+    #     try:
+    #         self.export()
+    #     except KeyboardInterrupt:
+    #         sys.exit(-1)
+
+    # def setup(self):
+    #     """Setup the experiment.
+    #     """
+    #     paddle.set_device(self.args.device)
+
+    #     self.setup_output_dir()
+    #     self.setup_checkpointer()
+
+    #     self.setup_dataloader()
+    #     self.setup_model()
+
+    #     self.iteration = 0
+    #     self.epoch = 0
+
+    # def setup_output_dir(self):
+    #     """Create a directory used for output.
+    #     """
+    #     # output dir
+    #     if self.args.output:
+    #         output_dir = Path(self.args.output).expanduser()
+    #         output_dir.mkdir(parents=True, exist_ok=True)
+    #     else:
+    #         output_dir = Path(
+    #             self.args.checkpoint_path).expanduser().parent.parent
+    #         output_dir.mkdir(parents=True, exist_ok=True)
+    #     self.output_dir = output_dir
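With return_utts=True on the test collator, each batch gains the utterance ids as its last element, which is why the tester now unpacks *batch[:-1]. A toy rehearsal of that convention (values and the compute_metrics stub are made up):

# Test batches become (audios, audio_lens, texts, text_lens, utts); metrics
# are computed on everything except the trailing utterance ids.
def compute_metrics(audio, audio_len, text, text_len):
    return {"num_frames": sum(audio_len)}

batch = (["a1", "a2"], [10, 12], ["t1", "t2"], [2, 3], ["utt-1", "utt-2"])
utts = batch[-1]                          # ids appended by return_utts=True
metrics = compute_metrics(*batch[:-1])    # audio, audio_len, text, text_len
print(utts, metrics)                      # ['utt-1', 'utt-2'] {'num_frames': 22}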
deepspeech/frontend/featurizer/text_featurizer.py

@@ -14,12 +14,27 @@
 """Contains the text featurizer class."""
 import sentencepiece as spm

-from deepspeech.frontend.utility import EOS
-from deepspeech.frontend.utility import UNK
+from ..utility import EOS
+from ..utility import SPACE
+from ..utility import UNK
+from ..utility import SOS
+from ..utility import BLANK
+from ..utility import MASKCTC
+from ..utility import load_dict
 from deepspeech.utils.log import Log

-class TextFeaturizer(object):
-    def __init__(self, unit_type, vocab_filepath, spm_model_prefix=None):
+logger = Log(__name__).getlog()
+
+__all__ = ["TextFeaturizer"]
+
+
+class TextFeaturizer():
+    def __init__(self,
+                 unit_type,
+                 vocab_filepath,
+                 spm_model_prefix=None,
+                 maskctc=False):
         """Text featurizer, for processing or extracting features from text.

         Currently, it supports char/word/sentence-piece level tokenizing and conversion into

@@ -34,20 +49,21 @@ class TextFeaturizer(object):
         assert unit_type in ('char', 'spm', 'word')
         self.unit_type = unit_type
         self.unk = UNK
+        self.maskctc = maskctc

         if vocab_filepath:
-            self._vocab_dict, self._id2token, self._vocab_list = self._load_vocabulary_from_file(
-                vocab_filepath)
-            self.unk_id = self._vocab_list.index(self.unk)
-            self.eos_id = self._vocab_list.index(EOS)
+            self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id = self._load_vocabulary_from_file(
+                vocab_filepath, maskctc)
+            self.vocab_size = len(self.vocab_list)

         if unit_type == 'spm':
             spm_model = spm_model_prefix + '.model'
             self.sp = spm.SentencePieceProcessor()
             self.sp.Load(spm_model)

-    def tokenize(self, text):
+    def tokenize(self, text, replace_space=True):
         if self.unit_type == 'char':
-            tokens = self.char_tokenize(text)
+            tokens = self.char_tokenize(text, replace_space)
         elif self.unit_type == 'word':
             tokens = self.word_tokenize(text)
         else:  # spm

@@ -67,27 +83,27 @@ class TextFeaturizer(object):
         """Convert text string to a list of token indices.

         Args:
-            text (str): Text to process.
+            text (str): Text.

         Returns:
             List[int]: List of token indices.
         """
         tokens = self.tokenize(text)
         ids = []
         for token in tokens:
-            token = token if token in self._vocab_dict else self.unk
-            ids.append(self._vocab_dict[token])
+            token = token if token in self.vocab_dict else self.unk
+            ids.append(self.vocab_dict[token])
         return ids

     def defeaturize(self, idxs):
         """Convert a list of token indices to text string,
-        ignore index after eos_id.
+        ignore index after eos_id.

         Args:
             idxs (List[int]): List of token indices.

         Returns:
-            str: Text to process.
+            str: Text.
         """
         tokens = []
         for idx in idxs:

@@ -97,43 +113,22 @@ class TextFeaturizer(object):
         text = self.detokenize(tokens)
         return text

-    @property
-    def vocab_size(self):
-        """Return the vocabulary size.
-
-        :return: Vocabulary size.
-        :rtype: int
-        """
-        return len(self._vocab_list)
-
-    @property
-    def vocab_list(self):
-        """Return the vocabulary in list.
-
-        Returns:
-            List[str]: tokens.
-        """
-        return self._vocab_list
-
-    @property
-    def vocab_dict(self):
-        """Return the vocabulary in dict.
-
-        Returns:
-            Dict[str, int]: token str -> int
-        """
-        return self._vocab_dict
-
-    def char_tokenize(self, text):
+    def char_tokenize(self, text, replace_space=True):
         """Character tokenizer.

         Args:
             text (str): text string.
+            replace_space (bool): False only used by build_vocab.py.

         Returns:
             List[str]: tokens.
         """
-        return list(text.strip())
+        text = text.strip()
+        if replace_space:
+            text_list = [SPACE if item == " " else item for item in list(text)]
+        else:
+            text_list = list(text)
+        return text_list

     def char_detokenize(self, tokens):
         """Character detokenizer.

@@ -144,6 +139,7 @@ class TextFeaturizer(object):
         Returns:
             str: text string.
         """
+        tokens = tokens.replace(SPACE, " ")
         return "".join(tokens)

     def word_tokenize(self, text):

@@ -206,14 +202,28 @@ class TextFeaturizer(object):
         return decode(tokens)

-    def _load_vocabulary_from_file(self, vocab_filepath):
+    def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool):
         """Load vocabulary from file."""
-        vocab_lines = []
-        with open(vocab_filepath, 'r', encoding='utf-8') as file:
-            vocab_lines.extend(file.readlines())
-        vocab_list = [line[:-1] for line in vocab_lines]
+        vocab_list = load_dict(vocab_filepath, maskctc)
+        assert vocab_list is not None
+        logger.info(f"Vocab: {vocab_list}")
+
         id2token = dict(
             [(idx, token) for (idx, token) in enumerate(vocab_list)])
         token2id = dict(
             [(token, idx) for (idx, token) in enumerate(vocab_list)])
-        return token2id, id2token, vocab_list
+
+        blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
+        maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1
+        unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
+        eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1
+        sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
+        space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1
+
+        logger.info(f"UNK id: {unk_id}")
+        logger.info(f"EOS id: {eos_id}")
+        logger.info(f"SOS id: {sos_id}")
+        logger.info(f"SPACE id: {space_id}")
+        logger.info(f"BLANK id: {blank_id}")
+        logger.info(f"MASKCTC id: {maskctc_id}")
+
+        return token2id, id2token, vocab_list, unk_id, eos_id
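The new char-level behavior maps literal spaces to the <space> token on the way in and back to " " on the way out. A self-contained, illustrative round trip (free-standing functions mirroring the diff, not the class methods themselves):

# Illustrative sketch of the space handling added to char tokenization.
SPACE = "<space>"

def char_tokenize(text, replace_space=True):
    text = text.strip()
    if replace_space:
        return [SPACE if ch == " " else ch for ch in text]
    return list(text)

def char_detokenize(tokens):
    # join first, then restore spaces from the <space> marker
    return "".join(tokens).replace(SPACE, " ")

tokens = char_tokenize("hi there")
assert tokens == ["h", "i", SPACE, "t", "h", "e", "r", "e"]
assert char_detokenize(tokens) == "hi there"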
deepspeech/frontend/utility.py

@@ -12,10 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Contains data helper functions."""
 import codecs
 import json
 import math
 import tarfile
 from collections import namedtuple
+from typing import List
+from typing import Optional
+from typing import Text

+import jsonlines
 import numpy as np

 from deepspeech.utils.log import Log

@@ -23,16 +28,40 @@ from deepspeech.utils.log import Log
 logger = Log(__name__).getlog()

-__all__ = [
-    "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs",
-    "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", "EOS", "UNK",
-    "BLANK"
-]
+__all__ = [
+    "load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
+    "max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
+    "EOS", "UNK", "BLANK", "MASKCTC", "SPACE"
+]

 IGNORE_ID = -1
-SOS = "<sos/eos>"
+# `sos` and `eos` using same token
+SOS = "<eos>"
 EOS = SOS
 UNK = "<unk>"
 BLANK = "<blank>"
+MASKCTC = "<mask>"
+SPACE = "<space>"
+
+
+def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
+    if dict_path is None:
+        return None
+
+    with open(dict_path, "r") as f:
+        dictionary = f.readlines()
+    # first token is `<blank>`
+    # multi line: `<blank> 0\n`
+    # one line: `<blank>`
+    # space is relpace with <space>
+    char_list = [entry[:-1].split(" ")[0] for entry in dictionary]
+    if BLANK not in char_list:
+        char_list.insert(0, BLANK)
+    if EOS not in char_list:
+        char_list.append(EOS)
+    # for non-autoregressive maskctc model
+    if maskctc and MASKCTC not in char_list:
+        char_list.append(MASKCTC)
+    return char_list

 def read_manifest(

@@ -47,12 +76,20 @@ def read_manifest(
     Args:
         manifest_path ([type]): Manifest file to load and parse.
-        max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
-        min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
-        max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
-        min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
-        max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
-        min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
+        max_input_len ([type], optional): maximum output seq length,
+            in seconds for raw wav, in frame numbers for feature data.
+            Defaults to float('inf').
+        min_input_len (float, optional): minimum input seq length,
+            in seconds for raw wav, in frame numbers for feature data.
+            Defaults to 0.0.
+        max_output_len (float, optional): maximum input seq length,
+            in modeling units. Defaults to 500.0.
+        min_output_len (float, optional): minimum input seq length,
+            in modeling units. Defaults to 0.0.
+        max_output_input_ratio (float, optional):
+            maximum output seq length/output seq length ratio. Defaults to 10.0.
+        min_output_input_ratio (float, optional):
+            minimum output seq length/output seq length ratio. Defaults to 0.05.

     Raises:
         IOError: If failed to parse the manifest.

@@ -62,29 +99,70 @@ def read_manifest(
     """
     manifest = []
-    for json_line in codecs.open(manifest_path, 'r', 'utf-8'):
-        try:
-            json_data = json.loads(json_line)
-        except Exception as e:
-            raise IOError("Error reading manifest: %s" % str(e))
-        feat_len = json_data["feat_shape"][0] if 'feat_shape' in json_data else 1.0
-        token_len = json_data["token_shape"][0] if 'token_shape' in json_data else 1.0
-        conditions = [
-            feat_len >= min_input_len,
-            feat_len <= max_input_len,
-            token_len >= min_output_len,
-            token_len <= max_output_len,
-            token_len / feat_len >= min_output_input_ratio,
-            token_len / feat_len <= max_output_input_ratio,
-        ]
-        if all(conditions):
-            manifest.append(json_data)
+    with jsonlines.open(manifest_path, 'r') as reader:
+        for json_data in reader:
+            feat_len = json_data["feat_shape"][0] if 'feat_shape' in json_data else 1.0
+            token_len = json_data["token_shape"][0] if 'token_shape' in json_data else 1.0
+            conditions = [
+                feat_len >= min_input_len,
+                feat_len <= max_input_len,
+                token_len >= min_output_len,
+                token_len <= max_output_len,
+                token_len / feat_len >= min_output_input_ratio,
+                token_len / feat_len <= max_output_input_ratio,
+            ]
+            if all(conditions):
+                manifest.append(json_data)
     return manifest

+
+# Tar File read
+TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
+
+
+def parse_tar(file):
+    """Parse a tar file to get a tarfile object
+    and a map containing tarinfoes
+    """
+    result = {}
+    f = tarfile.open(file)
+    for tarinfo in f.getmembers():
+        result[tarinfo.name] = tarinfo
+    return f, result
+
+
+def subfile_from_tar(file, local_data=None):
+    """Get subfile object from tar.
+
+    tar:tarpath#filename
+
+    It will return a subfile object from tar file
+    and cached tar file info for next reading request.
+    """
+    tarpath, filename = file.split(':', 1)[1].split('#', 1)
+
+    if local_data is None:
+        local_data = TarLocalData(tar2info={}, tar2object={})
+
+    assert isinstance(local_data, TarLocalData)
+
+    if 'tar2info' not in local_data.__dict__:
+        local_data.tar2info = {}
+    if 'tar2object' not in local_data.__dict__:
+        local_data.tar2object = {}
+
+    if tarpath not in local_data.tar2info:
+        fobj, infos = parse_tar(tarpath)
+        local_data.tar2info[tarpath] = infos
+        local_data.tar2object[tarpath] = fobj
+    else:
+        fobj = local_data.tar2object[tarpath]
+        infos = local_data.tar2info[tarpath]
+    return fobj.extractfile(infos[filename])
+

 def rms_to_db(rms: float):
     """Root Mean Square to dB.

@@ -101,7 +179,7 @@ def rms_to_dbfs(rms: float):
     """Root Mean Square to dBFS.
     https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
     Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB.
-    dB = dBFS + 3.0103
+    dBFS = db - 3.0103
     e.g. 0 dB = -3.0103 dBFS

@@ -116,26 +194,26 @@ def rms_to_dbfs(rms: float):
(trailing-whitespace cleanup in the max_dbfs and mean_dbfs docstrings)

@@ -155,7 +233,7 @@ def gain_db_to_ratio(gain_db: float):
(trailing-whitespace cleanup in the normalize_audio docstring)

@@ -254,6 +332,13 @@ def load_cmvn(cmvn_file: str, filetype: str):
         cmvn = _load_json_cmvn(cmvn_file)
     elif filetype == "kaldi":
         cmvn = _load_kaldi_cmvn(cmvn_file)
+    elif filetype == "npz":
+        eps = 1e-14
+        npzfile = np.load(cmvn_file)
+        mean = np.squeeze(npzfile["mean"])
+        std = np.squeeze(npzfile["std"])
+        istd = 1 / (std + eps)
+        cmvn = [mean, istd]
     else:
         raise ValueError(f"cmvn file type no support: {filetype}")
     return cmvn[0], cmvn[1]
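load_dict's contract is one token per line (optionally "token id"), with <blank> forced to index 0, <eos> guaranteed at the end, and <mask> appended only for maskctc. A small standalone rehearsal of that logic; the temp file and its contents are made up for illustration:

# Standalone rehearsal of load_dict's normalization of a vocab file.
import tempfile

BLANK, EOS = "<blank>", "<eos>"

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("<unk> 0\na 1\nb 2\n")   # "token id" lines, as in the docstring
    path = f.name

with open(path) as f:
    char_list = [line[:-1].split(" ")[0] for line in f.readlines()]
if BLANK not in char_list:
    char_list.insert(0, BLANK)       # blank always first
if EOS not in char_list:
    char_list.append(EOS)            # eos always present
print(char_list)                     # ['<blank>', '<unk>', 'a', 'b', '<eos>']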
deepspeech/io/collator.py

@@ -23,7 +23,7 @@ logger = Log(__name__).getlog()

 class SpeechCollator():
-    def __init__(self, keep_transcription_text=True):
+    def __init__(self, keep_transcription_text=True, return_utts=False):
         """
         Padding audio features with zeros to make them have the same shape (or
         a user-defined shape) within one bach.

@@ -31,6 +31,7 @@ class SpeechCollator():
         if ``keep_transcription_text`` is False, text is token ids else is raw string.
         """
         self._keep_transcription_text = keep_transcription_text
+        self.return_utts = return_utts

     def __call__(self, batch):
         """batch examples

@@ -51,7 +52,9 @@ class SpeechCollator():
         audio_lens = []
         texts = []
         text_lens = []
-        for audio, text in batch:
+        utts = []
+        for utt, audio, text in batch:
+            utts.append(utt)
             # audio
             audios.append(audio.T)  # [T, D]
             audio_lens.append(audio.shape[1])

@@ -75,4 +78,7 @@ class SpeechCollator():
         padded_texts = pad_sequence(texts, padding_value=IGNORE_ID).astype(np.int64)
         text_lens = np.array(text_lens).astype(np.int64)
-        return padded_audios, audio_lens, padded_texts, text_lens
+        if self.return_utts:
+            return padded_audios, audio_lens, padded_texts, text_lens, utts
+        else:
+            return padded_audios, audio_lens, padded_texts, text_lens
\ No newline at end of file
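A minimal sketch of the collator's new return shape: with return_utts=True a batch carries the utterance ids as a fifth element, while the default keeps the old four-tuple. Toy values only; TinyCollator is a stand-in, not the real SpeechCollator:

# Toy model of the return_utts switch, skipping padding/featurization.
class TinyCollator:
    def __init__(self, return_utts=False):
        self.return_utts = return_utts

    def __call__(self, batch):
        utts, audios, texts = zip(*batch)        # items are (utt, audio, text)
        out = (list(audios), [len(a) for a in audios],
               list(texts), [len(t) for t in texts])
        return out + (list(utts),) if self.return_utts else out

batch = [("utt-1", [0.1, 0.2], "ab"), ("utt-2", [0.3], "c")]
print(TinyCollator(return_utts=True)(batch)[-1])  # ['utt-1', 'utt-2']
print(len(TinyCollator()(batch)))                 # 4, the old shape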
deepspeech/io/dataset.py

@@ -347,4 +347,5 @@ class ManifestDataset(Dataset):

     def __getitem__(self, idx):
         instance = self._manifest[idx]
-        return self.process_utterance(instance["feat"], instance["text"])
+        feat, text = self.process_utterance(instance["feat"], instance["text"])
+        return instance["utt"], feat, text
deepspeech/models/ds2/conv.py

@@ -26,9 +26,9 @@ __all__ = ['ConvStack', "conv_output_size"]
(trailing-whitespace cleanup in the conv_output_size comments: "By noting I the length of the input volume size, / F the length of the filter, / P the amount of zero padding, / S the stride, then the output size O of the feature map along that dimension is given by: O = (I - F + Pstart + Pend) // S + 1")

@@ -45,7 +45,7 @@ def conv_output_size(I, F, P, S):
(trailing-whitespace cleanup in the receptive-field comment "Rl-1 = Sl * Rl + (Kl - Sl)")

@@ -58,8 +58,8 @@ class ConvBn(nn.Layer):
(trailing-whitespace cleanup in the ":param stride:" docstring lines)

@@ -114,7 +114,7 @@ class ConvBn(nn.Layer):
         masks = make_non_pad_mask(x_len)  #[B, T]
         masks = masks.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
         # TODO(Hui Zhang): not support bool multiply
-        masks = masks.type_as(x)
+        masks = masks.astype(x.dtype)
         x = x.multiply(masks)

         return x, x_len
deepspeech/models/ds2/deepspeech2.py

@@ -219,15 +219,17 @@ class DeepSpeech2Model(nn.Layer):
             The model built from pretrained result.
         """
-        model = cls(feat_size=dataloader.collate_fn.feature_size,
-                    dict_size=dataloader.collate_fn.vocab_size,
+        model = cls(
+            #feat_size=dataloader.collate_fn.feature_size,
+            feat_size=dataloader.dataset.feature_size,
+            #dict_size=dataloader.collate_fn.vocab_size,
+            dict_size=dataloader.dataset.vocab_size,
             num_conv_layers=config.model.num_conv_layers,
             num_rnn_layers=config.model.num_rnn_layers,
             rnn_size=config.model.rnn_layer_size,
             use_gru=config.model.use_gru,
             share_rnn_weights=config.model.share_rnn_weights,
             blank_id=config.model.blank_id,
-            ctc_grad_norm_type=config.ctc_grad_norm_type, )
+            ctc_grad_norm_type=config.model.ctc_grad_norm_type, )
         infos = Checkpoint().load_parameters(
             model, checkpoint_path=checkpoint_path)
         logger.info(f"checkpoint info: {infos}")

@@ -260,24 +262,8 @@ class DeepSpeech2Model(nn.Layer):
 class DeepSpeech2InferModel(DeepSpeech2Model):
-    def __init__(self,
-                 feat_size,
-                 dict_size,
-                 num_conv_layers=2,
-                 num_rnn_layers=3,
-                 rnn_size=1024,
-                 use_gru=False,
-                 share_rnn_weights=True,
-                 blank_id=0):
-        super().__init__(
-            feat_size=feat_size,
-            dict_size=dict_size,
-            num_conv_layers=num_conv_layers,
-            num_rnn_layers=num_rnn_layers,
-            rnn_size=rnn_size,
-            use_gru=use_gru,
-            share_rnn_weights=share_rnn_weights,
-            blank_id=blank_id)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

     def forward(self, audio, audio_len):
         """export model function
deepspeech/models/ds2/rnn.py

@@ -29,13 +29,13 @@ __all__ = ['RNNStack']
(trailing-whitespace cleanup in the RNNCell docstring: "Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it computes the outputs and updates states. ... h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}); y_{t} & = h_{t}; where :math:`act` is for :attr:`activation`.")

@@ -92,7 +92,7 @@ class RNNCell(nn.RNNCellBase):
(trailing-whitespace cleanup in the GRUCell docstring: "Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, it computes the outputs and updates states.")

@@ -101,8 +101,8 @@ class GRUCell(nn.RNNCellBase):
(trailing-whitespace cleanup: "where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise multiplication operator.")

@@ -309,6 +309,6 @@ class RNNStack(nn.Layer):
         masks = make_non_pad_mask(x_len)  #[B, T]
         masks = masks.unsqueeze(-1)  # [B, T, 1]
         # TODO(Hui Zhang): not support bool multiply
-        masks = masks.type_as(x)
+        masks = masks.astype(x.dtype)
         x = x.multiply(masks)
         return x, x_len
deepspeech/models/ds2_online/deepspeech2.py

@@ -255,22 +255,24 @@ class DeepSpeech2ModelOnline(nn.Layer):
                 fc_layers_size_list=[512, 256],
                 use_gru=True,  #Use gru if set True. Use simple rnn if set False.
                 blank_id=0,  # index of blank in vocob.txt
-            ))
+                ctc_grad_norm_type='instance', ))
         if config is not None:
             config.merge_from_other_cfg(default)
         return default

-    def __init__(self,
-                 feat_size,
-                 dict_size,
-                 num_conv_layers=2,
-                 num_rnn_layers=4,
-                 rnn_size=1024,
-                 rnn_direction='forward',
-                 num_fc_layers=2,
-                 fc_layers_size_list=[512, 256],
-                 use_gru=False,
-                 blank_id=0):
+    def __init__(self,
+                 feat_size,
+                 dict_size,
+                 num_conv_layers=2,
+                 num_rnn_layers=4,
+                 rnn_size=1024,
+                 rnn_direction='forward',
+                 num_fc_layers=2,
+                 fc_layers_size_list=[512, 256],
+                 use_gru=False,
+                 blank_id=0,
+                 ctc_grad_norm_type='instance', ):
         super().__init__()
         self.encoder = CRNNEncoder(
             feat_size=feat_size,

@@ -290,7 +292,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
             dropout_rate=0.0,
             reduction=True,  # sum
             batch_average=True,  # sum / batch_size
-            grad_norm_type='instance')
+            grad_norm_type=ctc_grad_norm_type)

     def forward(self, audio, audio_len, text, text_len):
         """Compute Model loss

@@ -348,16 +350,18 @@ class DeepSpeech2ModelOnline(nn.Layer):
             DeepSpeech2ModelOnline
             The model built from pretrained result.
         """
-        model = cls(feat_size=dataloader.collate_fn.feature_size,
-                    dict_size=dataloader.collate_fn.vocab_size,
-                    num_conv_layers=config.model.num_conv_layers,
-                    num_rnn_layers=config.model.num_rnn_layers,
-                    rnn_size=config.model.rnn_layer_size,
-                    rnn_direction=config.model.rnn_direction,
-                    num_fc_layers=config.model.num_fc_layers,
-                    fc_layers_size_list=config.model.fc_layers_size_list,
-                    use_gru=config.model.use_gru,
-                    blank_id=config.model.blank_id)
+        model = cls(
+            feat_size=dataloader.collate_fn.feature_size,
+            dict_size=dataloader.collate_fn.vocab_size,
+            num_conv_layers=config.model.num_conv_layers,
+            num_rnn_layers=config.model.num_rnn_layers,
+            rnn_size=config.model.rnn_layer_size,
+            rnn_direction=config.model.rnn_direction,
+            num_fc_layers=config.model.num_fc_layers,
+            fc_layers_size_list=config.model.fc_layers_size_list,
+            use_gru=config.model.use_gru,
+            blank_id=config.model.blank_id,
+            ctc_grad_norm_type=config.model.ctc_grad_norm_type, )
         infos = Checkpoint().load_parameters(
             model, checkpoint_path=checkpoint_path)
         logger.info(f"checkpoint info: {infos}")

@@ -376,42 +380,24 @@ class DeepSpeech2ModelOnline(nn.Layer):
             DeepSpeech2ModelOnline
             The model built from config.
         """
-        model = cls(feat_size=config.feat_size,
-                    dict_size=config.dict_size,
-                    num_conv_layers=config.num_conv_layers,
-                    num_rnn_layers=config.num_rnn_layers,
-                    rnn_size=config.rnn_layer_size,
-                    rnn_direction=config.rnn_direction,
-                    num_fc_layers=config.num_fc_layers,
-                    fc_layers_size_list=config.fc_layers_size_list,
-                    use_gru=config.use_gru,
-                    blank_id=config.blank_id)
+        model = cls(
+            feat_size=config.feat_size,
+            dict_size=config.dict_size,
+            num_conv_layers=config.num_conv_layers,
+            num_rnn_layers=config.num_rnn_layers,
+            rnn_size=config.rnn_layer_size,
+            rnn_direction=config.rnn_direction,
+            num_fc_layers=config.num_fc_layers,
+            fc_layers_size_list=config.fc_layers_size_list,
+            use_gru=config.use_gru,
+            blank_id=config.blank_id,
+            ctc_grad_norm_type=config.ctc_grad_norm_type, )
         return model


 class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
-    def __init__(self,
-                 feat_size,
-                 dict_size,
-                 num_conv_layers=2,
-                 num_rnn_layers=4,
-                 rnn_size=1024,
-                 rnn_direction='forward',
-                 num_fc_layers=2,
-                 fc_layers_size_list=[512, 256],
-                 use_gru=False,
-                 blank_id=0):
-        super().__init__(
-            feat_size=feat_size,
-            dict_size=dict_size,
-            num_conv_layers=num_conv_layers,
-            num_rnn_layers=num_rnn_layers,
-            rnn_size=rnn_size,
-            rnn_direction=rnn_direction,
-            num_fc_layers=num_fc_layers,
-            fc_layers_size_list=fc_layers_size_list,
-            use_gru=use_gru,
-            blank_id=blank_id)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

     def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
                 chunk_state_c_box):
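The net effect here is that the CTC gradient-normalization mode, previously hard-coded to 'instance', now flows from the YAML model config through the constructor into the loss. A stubbed illustration of that plumbing (CTCLoss and TinyModel are stand-ins, not the repo classes):

# Stubbed sketch of the ctc_grad_norm_type plumbing added by this commit.
class CTCLoss:
    def __init__(self, grad_norm_type):
        self.grad_norm_type = grad_norm_type  # e.g. 'instance'

class TinyModel:
    def __init__(self, ctc_grad_norm_type='instance'):
        # previously fixed to 'instance'; now taken from config.model
        self.loss = CTCLoss(grad_norm_type=ctc_grad_norm_type)

model = TinyModel(ctc_grad_norm_type='instance')
print(model.loss.grad_norm_type)  # instance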
deepspeech/training/trainer.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import time
+from contextlib import contextmanager
 from pathlib import Path

 import paddle

@@ -29,37 +30,37 @@ logger = Log(__name__).getlog()
(trailing-whitespace cleanup throughout the Trainer class docstring: the experiment-template description, the ``model``/``optimizer``/``train_loader`` conventions, the ``training`` config field with ``valid_interval``, ``save_interval`` and ``max_iteration`` keys, and the four methods ``train_batch``, ``valid``, ``setup_model`` and ``setup_dataloader`` to implement)

@@ -68,17 +69,17 @@ class Trainer():
     >>> exp = Trainer(config, args)
     >>> exp.setup()
     >>> exp.run()
     >>>
     >>> config = get_cfg_defaults()
     >>> parser = default_argument_parser()
     >>> args = parser.parse_args()
     >>> if args.config:
     >>>     config.merge_from_file(args.config)
     >>> if args.opts:
     >>>     config.merge_from_list(args.opts)
     >>> config.freeze()
     >>>
-    >>> if args.nprocs > 1 and args.device == "gpu":
+    >>> if args.nprocs > 0:
     >>>     dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
     >>> else:
     >>>     main_sp(config, args)

@@ -93,18 +94,24 @@ class Trainer():
         self.checkpoint_dir = None
         self.iteration = 0
         self.epoch = 0
+        self._train = True

-    def setup(self):
-        """Setup the experiment.
-        """
-        paddle.set_device(self.args.device)
+        paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
         if self.parallel:
             self.init_parallel()

+    @contextmanager
+    def eval(self):
+        self._train = False
+        yield
+        self._train = True
+
+    def setup(self):
+        """Setup the experiment.
+        """
         self.setup_output_dir()
         self.dump_config()
         self.setup_visualizer()
         self.setup_checkpointer()
         self.setup_dataloader()
         self.setup_model()

@@ -114,10 +121,10 @@ class Trainer():
     @property
     def parallel(self):
         """A flag indicating whether the experiment should run with
         multiprocessing.
         """
-        return self.args.device == "gpu" and self.args.nprocs > 1
+        return self.args.nprocs > 1

     def init_parallel(self):
         """Init environment for multiprocess training.

@@ -144,9 +151,9 @@ class Trainer():
(trailing-whitespace cleanup in the resume_or_scratch docstring: "Resume from latest checkpoint at checkpoints in the output directory or load a specified checkpoint. If ``args.checkpoint_path`` is not None, load the checkpoint, else resume training.")

@@ -158,8 +165,8 @@ class Trainer():
             checkpoint_path=self.args.checkpoint_path)
         if infos:
             # restore from ckpt
-            self.iteration = infos["step"]
-            self.epoch = infos["epoch"]
+            self.iteration = infos["step"] + 1
+            self.epoch = infos["epoch"] + 1
             scratch = False
         else:
             self.iteration = 0

@@ -237,31 +244,61 @@ class Trainer():
         try:
             self.train()
         except KeyboardInterrupt:
             self.save()
             exit(-1)
         finally:
             self.destory()
-        logger.info("Training Done.")
+        logger.info("Train Done.")
+
+    def run_test(self):
+        """Do Test/Decode"""
+        with self.eval():
+            self.resume_or_scratch()
+            try:
+                self.test()
+            except KeyboardInterrupt:
+                exit(-1)
+        logger.info("Test/Decode Done.")
+
+    def run_export(self):
+        """Do Model Export"""
+        with self.eval():
+            try:
+                self.export()
+            except KeyboardInterrupt:
+                exit(-1)
+        logger.info("Export Done.")

     def setup_output_dir(self):
         """Create a directory used for output.
         """
         # output dir
-        output_dir = Path(self.args.output).expanduser()
-        output_dir.mkdir(parents=True, exist_ok=True)
+        if self.args.output:
+            output_dir = Path(self.args.output).expanduser()
+        elif self.args.checkpoint_path:
+            output_dir = Path(self.args.checkpoint_path).expanduser().parent.parent
         self.output_dir = output_dir
+        self.output_dir.mkdir(parents=True, exist_ok=True)

     def setup_checkpointer(self):
         """Create a directory used to save checkpoints into.

         It is "checkpoints" inside the output directory.
         """
         # checkpoint dir
-        checkpoint_dir = self.output_dir / "checkpoints"
-        checkpoint_dir.mkdir(exist_ok=True)
+        self.checkpoint_dir = self.output_dir / "checkpoints"
+        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+        self.log_dir = output_dir / "log"
+        self.log_dir.mkdir(parents=True, exist_ok=True)
+
+        self.test_dir = output_dir / "test"
+        self.test_dir.mkdir(parents=True, exist_ok=True)
+
+        self.decode_dir = output_dir / "decode"
+        self.decode_dir.mkdir(parents=True, exist_ok=True)

-        self.checkpoint_dir = checkpoint_dir
+        self.export_dir = output_dir / "export"
+        self.export_dir.mkdir(parents=True, exist_ok=True)
+
+        self.visual_dir = output_dir / "visual"
+        self.visual_dir.mkdir(parents=True, exist_ok=True)
+
+        self.config_dir = output_dir / "conf"
+        self.config_dir.mkdir(parents=True, exist_ok=True)

     @mp_tools.rank_zero_only
     def destory(self):

@@ -273,27 +310,34 @@ class Trainer():
     @mp_tools.rank_zero_only
     def setup_visualizer(self):
         """Initialize a visualizer to log the experiment.

         The visual log is saved in the output directory.

         Notes
         ------
         Only the main process has a visualizer with it. Use multiple
         visualizers in multiprocess to write to a same log file may cause
         unexpected behaviors.
         """
         # visualizer
-        visualizer = SummaryWriter(logdir=str(self.output_dir))
+        visualizer = SummaryWriter(logdir=str(self.visual_dir))
         self.visualizer = visualizer

     @mp_tools.rank_zero_only
     def dump_config(self):
         """Save the configuration used for this experiment.
         It is saved in to ``config.yaml`` in the output directory at the
         beginning of the experiment.
         """
-        with open(self.output_dir / "config.yaml", 'wt') as f:
+        config_file = self.config_dir / "config.yaml"
+        if self._train and config_file.exists():
+            time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime())
+            target_path = self.config_dir / ".".join([time_stamp, "config.yaml"])
+            config_file.rename(target_path)
+
+        with open(config_file, 'wt') as f:
             print(self.config, file=f)

     def train_batch(self):

@@ -307,14 +351,26 @@ class Trainer():
         """
         raise NotImplementedError("valid should be implemented.")

+    @paddle.no_grad()
     def test(self):
         """The test. A subclass should implement this method in Tester.
         """
         raise NotImplementedError("test should be implemented.")

+    @paddle.no_grad()
+    def export(self):
+        """The test. A subclass should implement this method in Tester.
+        """
+        raise NotImplementedError("export should be implemented.")
+
     def setup_model(self):
         """Setup model, criterion and optimizer, etc. A subclass should
         implement this method.
         """
         raise NotImplementedError("setup_model should be implemented.")

     def setup_dataloader(self):
         """Setup training dataloader and validation dataloader. A subclass
         should implement this method.
         """
         raise NotImplementedError("setup_dataloader should be implemented.")
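The new Trainer.eval() context manager is the hinge of this refactor: it flips a _train flag so helpers such as dump_config can tell test/decode/export runs apart from training runs. A standalone sketch of the pattern (TinyTrainer is a stand-in for the real class):

# Standalone rehearsal of the eval() contextmanager pattern.
from contextlib import contextmanager

class TinyTrainer:
    def __init__(self):
        self._train = True

    @contextmanager
    def eval(self):
        self._train = False   # entering test/decode/export mode
        yield
        self._train = True    # restore training mode on exit

t = TinyTrainer()
with t.eval():
    assert t._train is False  # run_test/run_export bodies execute here
assert t._train is True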
deepspeech/utils/log.py

@@ -120,14 +120,15 @@ class Autolog:
                  model_precision="fp32"):
         import auto_log
         pid = os.getpid()
-        if (os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
+        if os.environ.get('CUDA_VISIBLE_DEVICES', None):
             gpu_id = int(os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0])
             infer_config = inference.Config()
             infer_config.enable_use_gpu(100, gpu_id)
         else:
             gpu_id = None
             infer_config = inference.Config()
-        autolog = auto_log.AutoLogger(
+        self.autolog = auto_log.AutoLogger(
             model_name=model_name,
             model_precision=model_precision,
             batch_size=batch_size,

@@ -139,7 +140,6 @@ class Autolog:
             gpu_ids=gpu_id,
             time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
             warmup=0)
-        self.autolog = autolog

     def getlog(self):
         return self.autolog
examples/dataset/mini_librispeech/.gitignore

@@ -2,3 +2,4 @@ dev-clean/
 manifest.dev-clean
 manifest.train-clean
 train-clean/
+*.meta
examples/dataset/mini_librispeech/mini_librispeech.py

@@ -58,6 +58,10 @@ def create_manifest(data_dir, manifest_path):
     """
     print("Creating manifest %s ..." % manifest_path)
    json_lines = []
+    total_sec = 0.0
+    total_text = 0.0
+    total_num = 0
+
     for subfolder, _, filelist in sorted(os.walk(data_dir)):
         text_filelist = [
             filename for filename in filelist if filename.endswith('trans.txt')

@@ -80,10 +84,27 @@ def create_manifest(data_dir, manifest_path):
                     'text': text
                 }))
+                total_sec += duration
+                total_text += len(text)
+                total_num += 1
+
     with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
         for line in json_lines:
             out_file.write(line + '\n')

+    subset = os.path.splitext(manifest_path)[1][1:]
+    manifest_dir = os.path.dirname(manifest_path)
+    data_dir_name = os.path.split(data_dir)[-1]
+    meta_path = os.path.join(manifest_dir, data_dir_name) + '.meta'
+    with open(meta_path, 'w') as f:
+        print(f"{subset}:", file=f)
+        print(f"{total_num} utts", file=f)
+        print(f"{total_sec / (60 * 60)} h", file=f)
+        print(f"{total_text} text", file=f)
+        print(f"{total_text / total_sec} text/sec", file=f)
+        print(f"{total_sec / total_num} sec/utt", file=f)
+

 def prepare_dataset(url, md5sum, target_dir, manifest_path):
     """Download, unpack and create summmary manifest file.
examples/librispeech/s1/local/align.sh (new file, mode 100755)

#!/bin/bash

if [ $# != 2 ];then
    echo "usage: ${0} config_path ckpt_path_prefix"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
ckpt_prefix=$2

batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}

# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}

if [ $? -ne 0 ]; then
    echo "Failed in ctc alignment!"
    exit 1
fi

exit 0
examples/librispeech/s1/local/data.sh

-#!/usr/bin/env bash
+#!/bin/bash

 stage=-1
 stop_stage=100
examples/librispeech/s1/local/download_lm_en.sh

-#!/usr/bin/env bash
+#!/bin/bash

 . ${MAIN_ROOT}/utils/utility.sh
examples/librispeech/s1/local/export.sh

-#!/usr/bin/env bash
+#!/bin/bash

 if [ $# != 3 ];then
     echo "usage: $0 config_path ckpt_prefix jit_model_path"

@@ -12,13 +12,7 @@ config_path=$1
 ckpt_path_prefix=$2
 jit_model_export_path=$3

-device=gpu
-if [ ngpu == 0 ];then
-    device=cpu
-fi
-
 python3 -u ${BIN_DIR}/export.py \
---device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \
 --checkpoint_path ${ckpt_path_prefix} \
examples/librispeech/s1/local/test.sh

-#!/usr/bin/env bash
+#!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+set -e
+
+expdir=exp
+datadir=data
+
+nj=32
+lmtag=
+
+recog_set="test-clean test-other dev-clean dev-other"
+recog_set="test-clean"
+
+# bpemode (unigram or bpe)
+nbpe=5000
+bpemode=unigram
+bpeprefix="data/bpe_${bpemode}_${nbpe}"
+bpemodel=${bpeprefix}.model
+
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path dict_path ckpt_path_prefix"
     exit -1
 fi

 ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 echo "using $ngpu gpus..."

-device=gpu
-if [ ngpu == 0 ];then
-    device=cpu
-fi
-
 config_path=$1
-ckpt_prefix=$2
+dict=$2
+ckpt_prefix=$3
+
+chunk_mode=false
+if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
+    chunk_mode=true
+fi
+echo "chunk mode ${chunk_mode}"

 # download language model
 #bash local/download_lm_en.sh

@@ -21,39 +42,46 @@ ckpt_prefix=$2
 #    exit 1
 #fi

-for type in attention ctc_greedy_search; do
-    echo "decoding ${type}"
-    batch_size=64
-    python3 -u ${BIN_DIR}/test.py \
-    --device ${device} \
-    --nproc 1 \
-    --config ${config_path} \
-    --result_file ${ckpt_prefix}.${type}.rsl \
-    --checkpoint_path ${ckpt_prefix} \
-    --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
-
-    if [ $? -ne 0 ]; then
-        echo "Failed in evaluation!"
-        exit 1
-    fi
-done
-
-for type in ctc_prefix_beam_search attention_rescoring; do
-    echo "decoding ${type}"
-    batch_size=1
-    python3 -u ${BIN_DIR}/test.py \
-    --device ${device} \
-    --nproc 1 \
-    --config ${config_path} \
-    --result_file ${ckpt_prefix}.${type}.rsl \
-    --checkpoint_path ${ckpt_prefix} \
-    --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
-
-    if [ $? -ne 0 ]; then
-        echo "Failed in evaluation!"
-        exit 1
-    fi
-done
+pids=() # initialize pids
+for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
+(
+    for rtask in ${recog_set}; do
+    (
+        decode_dir=decode_${rtask}_${dmethd}_$(basename ${config_path%.*})_${lmtag}
+        feat_recog_dir=${datadir}
+        mkdir -p ${expdir}/${decode_dir}
+        mkdir -p ${feat_recog_dir}
+
+        # split data
+        split_json.sh ${feat_recog_dir}/manifest.${rtask} ${nj}
+
+        #### use CPU for decoding
+        ngpu=0
+
+        # set batchsize 0 to disable batch decoding
+        batch_size=1
+        ${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \
+            python3 -u ${BIN_DIR}/test.py \
+            --nproc ${ngpu} \
+            --config ${config_path} \
+            --result_file ${expdir}/${decode_dir}/data.JOB.json \
+            --checkpoint_path ${ckpt_prefix} \
+            --opts decoding.decoding_method ${dmethd} \
+            --opts decoding.batch_size ${batch_size} \
+            --opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask}
+
+        score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true ${expdir}/${decode_dir} ${dict}
+
+    ) &
+    pids+=($!) # store background pids
+    done
+) &
+pids+=($!) # store background pids
+done
+i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
+[ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false
+echo "Finished"

 exit 0
examples/librispeech/s1/local/train.sh
浏览文件 @
8e73d184
#!
/usr/bin/env
bash
#!
/bin/
bash
if
[
$#
!=
2
]
;
then
echo
"usage: CUDA_VISIBLE_DEVICES=0
${
0
}
config_path ckpt_name"
...
...
@@ -11,19 +11,28 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2

device=gpu
if [ ${ngpu} == 0 ];then
    device=cpu
fi

mkdir -p exp

# seed may break model convergence
seed=0
if [ ${seed} != 0 ]; then
    #export FLAGS_cudnn_deterministic=True
    echo "None"
fi

echo "using ${device}..."

# export FLAGS_cudnn_exhaustive_search=true
# export FLAGS_conv_workspace_size_limit=4000
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--seed ${seed}

if [ ${seed} != 0 ]; then
    #unset FLAGS_cudnn_deterministic
    echo "None"
fi

if [ $? -ne 0 ]; then
    echo "Failed in training!"
...
...
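The seed handling above only echoes; the tiny recipes in this commit actually export FLAGS_cudnn_deterministic when seed != 0. A hedged manual equivalent for a reproducible run (the Paddle flag forces deterministic cuDNN kernels at some speed cost):

    # a sketch of a reproducible single-GPU run; checkpoint name is a placeholder
    export FLAGS_cudnn_deterministic=True
    CUDA_VISIBLE_DEVICES=0 ./local/train.sh conf/transformer.yaml transformer
    unset FLAGS_cudnn_deterministic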
examples/tiny/s0/conf/deepspeech2.yaml
...
...
@@ -4,6 +4,7 @@ data:
  dev_manifest: data/manifest.tiny
  test_manifest: data/manifest.tiny
  mean_std_filepath: data/mean_std.json
  unit_type: char
  vocab_filepath: data/vocab.txt
  augmentation_config: conf/augmentation.json
  batch_size: 4
...
...
@@ -35,6 +36,8 @@ model:
  rnn_layer_size: 2048
  use_gru: False
  share_rnn_weights: True
  blank_id: 0
  ctc_grad_norm_type: instance

training:
  n_epoch: 20
...
...
examples/tiny/s0/conf/deepspeech2_online.yaml (new file, mode 100644)
# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.tiny
  dev_manifest: data/manifest.tiny
  test_manifest: data/manifest.tiny
  min_input_len: 0.0
  max_input_len: 30.0
  min_output_len: 0.0
  max_output_len: 400.0
  min_output_input_ratio: 0.05
  max_output_input_ratio: 10.0

collator:
  mean_std_filepath: data/mean_std.json
  unit_type: char
  vocab_filepath: data/vocab.txt
  augmentation_config: conf/augmentation.json
  random_seed: 0
  spm_model_prefix:
  spectrum_type: linear
  feat_dim:
  delta_delta: False
  stride_ms: 10.0
  window_ms: 20.0
  n_fft: None
  max_freq: None
  target_sample_rate: 16000
  use_dB_normalization: True
  target_dB: -20
  dither: 1.0
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 0
  batch_size: 4

model:
  num_conv_layers: 2
  num_rnn_layers: 4
  rnn_layer_size: 2048
  rnn_direction: forward
  num_fc_layers: 2
  fc_layers_size_list: 512, 256
  use_gru: True
  blank_id: 0
  ctc_grad_norm_type: instance

training:
  n_epoch: 10
  accum_grad: 1
  lr: 1e-5
  lr_decay: 1.0
  weight_decay: 1e-06
  global_grad_clip: 5.0
  log_interval: 1
  checkpoint:
    kbest_n: 3
    latest_n: 2

decoding:
  batch_size: 128
  error_rate_type: wer
  decoding_method: ctc_beam_search
  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
  alpha: 2.5
  beta: 0.3
  beam_size: 500
  cutoff_prob: 1.0
  cutoff_top_n: 40
  num_proc_bsearch: 8
examples/tiny/s0/local/download_lm_en.sh
#!/bin/bash

. ${MAIN_ROOT}/utils/utility.sh
...
...
@@ -9,6 +9,11 @@ URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm

if [ -e $TARGET ];then
    echo "$TARGET exists."
    exit 0
fi

echo "Download language model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
...
...
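download is a helper from utils/utility.sh that fetches the file and verifies its MD5 against the value above. A manual equivalent, assuming wget and md5sum are available:

    # fetch the LM by hand and check the published MD5
    wget -c https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm -O $TARGET
    echo "099a601759d467cd0a8523ff939819c5  $TARGET" | md5sum -c -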
examples/tiny/s0/local/export.sh
#!/bin/bash

if [ $# != 4 ];then
    echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
    exit -1
fi
...
...
@@ -11,19 +11,14 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
model_type=$4

python3 -u ${BIN_DIR}/export.py \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
--export_path ${jit_model_export_path} \
--model_type ${model_type}

if [ $? -ne 0 ]; then
    echo "Failed in export!"
...
...
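With the added model_type argument, a typical invocation (paths follow the run.sh conventions below) looks like:

    # export the averaged tiny/s0 checkpoint to a jit model
    CUDA_VISIBLE_DEVICES=0 ./local/export.sh conf/deepspeech2.yaml \
        exp/deepspeech2/checkpoints/avg_1 \
        exp/deepspeech2/checkpoints/avg_1.jit offline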
examples/tiny/s0/local/test.sh
#!/bin/bash

if [ $# != 3 ];then
    echo "usage: ${0} config_path ckpt_path_prefix model_type"
    exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
ckpt_prefix=$2
model_type=$3

# download language model
bash local/download_lm_en.sh
...
...
@@ -22,11 +19,11 @@ if [ $? -ne 0 ]; then
fi
python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
...
...
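model_type appears to select between the offline (ds2) and online/streaming (ds2_online) DeepSpeech2 variants touched by this commit. A sketch of evaluating a streaming checkpoint, assuming the checkpoint layout produced by run.sh:

    # evaluate an online (streaming) checkpoint
    CUDA_VISIBLE_DEVICES=0 ./local/test.sh conf/deepspeech2_online.yaml \
        exp/deepspeech2_online/checkpoints/avg_1 online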
examples/tiny/s0/local/train.sh
#!/bin/bash

profiler_options=
# seed may break model convergence
seed=0

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

if [ ${seed} != 0 ]; then
    export FLAGS_cudnn_deterministic=True
    echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi

if [ $# != 3 ];then
    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
    exit -1
fi

config_path=$1
ckpt_name=$2
model_type=$3

mkdir -p exp

python3 -u ${BIN_DIR}/train.py \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--model_type ${model_type} \
--profiler-options "${profiler_options}" \
--seed ${seed}

if [ ${seed} != 0 ]; then
    unset FLAGS_cudnn_deterministic
fi

if [ $? -ne 0 ]; then
    echo "Failed in training!"
...
...
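parse_options.sh is the Kaldi-style parser: any leading --name value pairs are bound to same-named shell variables before the positional arguments are checked. A sketch, assuming that mapping:

    # override the seed via parse_options, then pass the three positionals
    CUDA_VISIBLE_DEVICES=0 ./local/train.sh --seed 1024 \
        conf/deepspeech2.yaml deepspeech2 offline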
examples/tiny/s0/path.sh
export MAIN_ROOT=`realpath ${PWD}/../../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
...
...
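realpath canonicalizes the ../.. segments, so MAIN_ROOT becomes a clean repository root in logs and derived paths; both forms are absolute, only the canonical one is readable:

    # the difference the realpath change makes
    cd examples/tiny/s0
    echo ${PWD}/../../../              # e.g. /repo/examples/tiny/s0/../../../
    echo `realpath ${PWD}/../../../`   # e.g. /repo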
examples/tiny/s0/run.sh
...
...
@@ -7,11 +7,12 @@ stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=offline

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F '.' '{print $1}')
###ckpt = deepspeech2
echo "checkpoint name ${ckpt}"

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
...
...
@@ -21,20 +22,20 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model, all `ckpt` under `exp` dir
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # avg n best model
    avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # test ckpt avg_n
    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # export ckpt avg_n
    CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
fi
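Because run.sh sources parse_options.sh, individual stages and the new model_type can be chosen from the command line; a sketch using the online config added above:

    # run only the test stage against the streaming model
    bash run.sh --stage 3 --stop_stage 3 \
        --conf_path conf/deepspeech2_online.yaml --model_type online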
examples/tiny/s1/conf/transformer.yaml
...
...
@@ -65,6 +65,8 @@ model:
  # hybrid CTC/attention
  model_conf:
    ctc_weight: 0.3
    ctc_dropoutrate: 0.0
    ctc_grad_norm_type: instance
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
...
...
examples/tiny/s1/local/align.sh (new file, mode 100755)
#!/bin/bash

if [ $# != 2 ];then
    echo "usage: ${0} config_path ckpt_path_prefix"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
ckpt_prefix=$2

batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}

# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}

if [ $? -ne 0 ]; then
    echo "Failed in ctc alignment!"
    exit 1
fi

exit 0
examples/tiny/s1/local/data.sh
#!/bin/bash

stage=-1
stop_stage=100
...
...
examples/tiny/s1/local/export.sh
#!/bin/bash

if [ $# != 3 ];then
    echo "usage: $0 config_path ckpt_prefix jit_model_path"
...
...
@@ -12,13 +12,7 @@ config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3

python3 -u ${BIN_DIR}/export.py \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
...
...
examples/tiny/s1/local/test.sh
#!/bin/bash

if [ $# != 2 ];then
    echo "usage: ${0} config_path ckpt_path_prefix"
...
...
@@ -8,30 +8,57 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
ckpt_prefix=$2

chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
    chunk_mode=true
fi

# download language model
#bash local/download_lm_en.sh
#if [ $? -ne 0 ]; then
#    exit 1
#fi

for type in attention ctc_greedy_search; do
    echo "decoding ${type}"
    if [ ${chunk_mode} == true ];then
        # stream decoding only support batchsize=1
        batch_size=1
    else
        batch_size=64
    fi
    python3 -u ${BIN_DIR}/test.py \
    --nproc ${ngpu} \
    --config ${config_path} \
    --result_file ${ckpt_prefix}.${type}.rsl \
    --checkpoint_path ${ckpt_prefix} \
    --opts decoding.decoding_method ${type} \
    --opts decoding.batch_size ${batch_size}

    if [ $? -ne 0 ]; then
        echo "Failed in evaluation!"
        exit 1
    fi
done

for type in ctc_prefix_beam_search attention_rescoring; do
    echo "decoding ${type}"
    batch_size=1
    python3 -u ${BIN_DIR}/test.py \
    --nproc ${ngpu} \
    --config ${config_path} \
    --result_file ${ckpt_prefix}.${type}.rsl \
    --checkpoint_path ${ckpt_prefix} \
    --opts decoding.decoding_method ${type} \
    --opts decoding.batch_size ${batch_size}

    if [ $? -ne 0 ]; then
        echo "Failed in evaluation!"
        exit 1
    fi
done

exit 0
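The chunk_mode test is a plain filename regex; chunk_conformer.yaml below is a hypothetical config name used only to show a match:

    # which config names count as streaming configs
    for cfg in conf/transformer.yaml conf/chunk_conformer.yaml; do
        if [[ ${cfg} =~ ^.*chunk_.*yaml$ ]]; then
            echo "${cfg}: streaming (batch_size forced to 1)"
        else
            echo "${cfg}: non-streaming"
        fi
    done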
examples/tiny/s1/local/train.sh
#!/bin/bash

profiler_options=
benchmark_batch_size=0
benchmark_max_step=0

# seed may break model convergence
seed=0

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

if [ ${seed} != 0 ]; then
    export FLAGS_cudnn_deterministic=True
    echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi

if [ $# != 2 ];then
    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
    exit -1
fi

config_path=$1
ckpt_name=$2

mkdir -p exp

python3 -u ${BIN_DIR}/train.py \
--seed ${seed} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}

if [ ${seed} != 0 ]; then
    unset FLAGS_cudnn_deterministic
fi

if [ $? -ne 0 ]; then
    echo "Failed in training!"
...
...
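The benchmark options parsed at the top bound a short profiled run; assuming parse_options.sh maps dashed option names to the underscored variables above, such a run could look like:

    # a short benchmark: small batch, stop after 100 steps
    CUDA_VISIBLE_DEVICES=0 ./local/train.sh \
        --benchmark-batch-size 4 --benchmark-max-step 100 \
        conf/transformer.yaml transformer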
examples/tiny/s1/run.sh
...
...
@@ -20,20 +20,26 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model, all `ckpt` under `exp` dir
    ./local/train.sh ${conf_path} ${ckpt}
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # avg n best model
    avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # test ckpt avg_n
    CUDA_VISIBLE_DEVICES= ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # ctc alignment of test data
    CUDA_VISIBLE_DEVICES= ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    # export ckpt avg_n
    CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
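End to end, the s1 tiny recipe is driven stage by stage; a sketch assuming conf_path and avg_num defaults are defined at the top of run.sh:

    # full pipeline: data, train, average, test, align, export
    cd examples/tiny/s1
    bash run.sh --stage 0 --stop_stage 5 \
        --conf_path conf/transformer.yaml --avg_num 1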