Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
561d5cf0
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
1 年多 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
561d5cf0
编写于
8月 23, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor feature, dict and argument for new config format
上级
27daa92a
变更
23
隐藏空白更改
内联
并排
Showing
23 changed file
with
158 addition
and
100 deletion
+158
-100
.flake8
.flake8
+4
-0
deepspeech/exps/deepspeech2/bin/export.py
deepspeech/exps/deepspeech2/bin/export.py
+3
-0
deepspeech/exps/deepspeech2/bin/test.py
deepspeech/exps/deepspeech2/bin/test.py
+3
-0
deepspeech/exps/u2/bin/alignment.py
deepspeech/exps/u2/bin/alignment.py
+3
-0
deepspeech/exps/u2/bin/export.py
deepspeech/exps/u2/bin/export.py
+3
-0
deepspeech/exps/u2/bin/test.py
deepspeech/exps/u2/bin/test.py
+3
-0
deepspeech/exps/u2_kaldi/bin/test.py
deepspeech/exps/u2_kaldi/bin/test.py
+9
-0
deepspeech/exps/u2_kaldi/model.py
deepspeech/exps/u2_kaldi/model.py
+22
-10
deepspeech/exps/u2_st/bin/export.py
deepspeech/exps/u2_st/bin/export.py
+3
-0
deepspeech/exps/u2_st/bin/test.py
deepspeech/exps/u2_st/bin/test.py
+3
-0
deepspeech/frontend/featurizer/__init__.py
deepspeech/frontend/featurizer/__init__.py
+3
-0
deepspeech/frontend/featurizer/audio_featurizer.py
deepspeech/frontend/featurizer/audio_featurizer.py
+1
-1
deepspeech/frontend/featurizer/speech_featurizer.py
deepspeech/frontend/featurizer/speech_featurizer.py
+1
-1
deepspeech/frontend/featurizer/text_featurizer.py
deepspeech/frontend/featurizer/text_featurizer.py
+28
-45
deepspeech/frontend/utility.py
deepspeech/frontend/utility.py
+40
-10
deepspeech/training/cli.py
deepspeech/training/cli.py
+0
-7
examples/aishell/s0/conf/augmentation.json
examples/aishell/s0/conf/augmentation.json
+1
-1
examples/librispeech/s2/conf/transformer.yaml
examples/librispeech/s2/conf/transformer.yaml
+2
-8
examples/librispeech/s2/local/align.sh
examples/librispeech/s2/local/align.sh
+8
-5
examples/librispeech/s2/local/export.sh
examples/librispeech/s2/local/export.sh
+2
-1
examples/librispeech/s2/local/test.sh
examples/librispeech/s2/local/test.sh
+12
-7
examples/librispeech/s2/run.sh
examples/librispeech/s2/run.sh
+3
-2
examples/tiny/s0/conf/augmentation.json
examples/tiny/s0/conf/augmentation.json
+1
-2
未找到文件。
.flake8
浏览文件 @
561d5cf0
...
...
@@ -42,6 +42,10 @@ ignore =
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
per-file-ignores =
*/__init__.py: F401
# Specify the list of error codes you wish Flake8 to report.
select =
E,
...
...
deepspeech/exps/deepspeech2/bin/export.py
浏览文件 @
561d5cf0
...
...
@@ -30,6 +30,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
# save jit model to
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
parser
.
add_argument
(
"--model_type"
)
args
=
parser
.
parse_args
()
if
args
.
model_type
is
None
:
...
...
deepspeech/exps/deepspeech2/bin/test.py
浏览文件 @
561d5cf0
...
...
@@ -31,6 +31,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
parser
.
add_argument
(
"--model_type"
)
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
if
args
.
model_type
is
None
:
...
...
deepspeech/exps/u2/bin/alignment.py
浏览文件 @
561d5cf0
...
...
@@ -30,6 +30,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2/bin/export.py
浏览文件 @
561d5cf0
...
...
@@ -30,6 +30,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
# save jit model to
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2/bin/test.py
浏览文件 @
561d5cf0
...
...
@@ -34,6 +34,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2_kaldi/bin/test.py
浏览文件 @
561d5cf0
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
"""Evaluation for U2 model."""
import
cProfile
from
yacs.config
import
CfgNode
from
deepspeech.training.cli
import
default_argument_parser
...
...
@@ -54,6 +55,14 @@ if __name__ == "__main__":
type
=
str
,
default
=
'test'
,
help
=
'run mode, e.g. test, align, export'
)
parser
.
add_argument
(
'--dict-path'
,
type
=
str
,
default
=
None
,
help
=
'dict path.'
)
# save asr result to
parser
.
add_argument
(
"--result-file"
,
type
=
str
,
help
=
"path of save the asr result"
)
# save jit model to
parser
.
add_argument
(
"--export-path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2_kaldi/model.py
浏览文件 @
561d5cf0
...
...
@@ -25,6 +25,8 @@ import paddle
from
paddle
import
distributed
as
dist
from
yacs.config
import
CfgNode
from
deepspeech.frontend.featurizer
import
TextFeaturizer
from
deepspeech.frontend.utility
import
load_dict
from
deepspeech.io.dataloader
import
BatchDataLoader
from
deepspeech.models.u2
import
U2Model
from
deepspeech.training.optimizer
import
OptimizerFactory
...
...
@@ -80,8 +82,8 @@ class U2Trainer(Trainer):
def
train_batch
(
self
,
batch_index
,
batch_data
,
msg
):
train_conf
=
self
.
config
.
training
start
=
time
.
time
()
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch_data
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch_data
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
text_len
)
# loss div by `batch_size * accum_grad`
...
...
@@ -124,6 +126,7 @@ class U2Trainer(Trainer):
valid_losses
=
defaultdict
(
list
)
num_seen_utts
=
1
total_loss
=
0.0
for
i
,
batch
in
enumerate
(
self
.
valid_loader
):
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
...
...
@@ -305,10 +308,8 @@ class U2Trainer(Trainer):
model_conf
.
output_dim
=
self
.
train_loader
.
vocab_size
model_conf
.
freeze
()
model
=
U2Model
.
from_config
(
model_conf
)
if
self
.
parallel
:
model
=
paddle
.
DataParallel
(
model
)
logger
.
info
(
f
"
{
model
}
"
)
layer_tools
.
print_params
(
model
,
logger
.
info
)
...
...
@@ -379,13 +380,13 @@ class U2Tester(U2Trainer):
def
__init__
(
self
,
config
,
args
):
super
().
__init__
(
config
,
args
)
def
ordid2token
(
self
,
texts
,
texts_len
):
def
id2token
(
self
,
texts
,
texts_len
,
text_feature
):
""" ord() id to chr() chr """
trans
=
[]
for
text
,
n
in
zip
(
texts
,
texts_len
):
n
=
n
.
numpy
().
item
()
ids
=
text
[:
n
]
trans
.
append
(
''
.
join
([
chr
(
i
)
for
i
in
ids
]
))
trans
.
append
(
text_feature
.
defeaturize
(
ids
.
numpy
().
tolist
()
))
return
trans
def
compute_metrics
(
self
,
...
...
@@ -401,8 +402,11 @@ class U2Tester(U2Trainer):
error_rate_func
=
error_rate
.
cer
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
wer
start_time
=
time
.
time
()
text_feature
=
self
.
test_loader
.
collate_fn
.
text_feature
target_transcripts
=
self
.
ordid2token
(
texts
,
texts_len
)
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab_filepath
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
target_transcripts
=
self
.
id2token
(
texts
,
texts_len
,
text_feature
)
result_transcripts
=
self
.
model
.
decode
(
audio
,
audio_len
,
...
...
@@ -450,7 +454,7 @@ class U2Tester(U2Trainer):
self
.
model
.
eval
()
logger
.
info
(
f
"Test Total Examples:
{
len
(
self
.
test_loader
.
dataset
)
}
"
)
stride_ms
=
self
.
test_loader
.
collate_fn
.
stride_ms
stride_ms
=
self
.
config
.
collator
.
stride_ms
error_rate_type
=
None
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
num_frames
=
0.0
...
...
@@ -525,8 +529,9 @@ class U2Tester(U2Trainer):
self
.
model
.
eval
()
logger
.
info
(
f
"Align Total Examples:
{
len
(
self
.
align_loader
.
dataset
)
}
"
)
stride_ms
=
self
.
config
.
collate
.
stride_ms
token_dict
=
self
.
align_loader
.
collate_fn
.
vocab_list
stride_ms
=
self
.
config
.
collater
.
stride_ms
token_dict
=
self
.
args
.
char_list
with
open
(
self
.
args
.
result_file
,
'w'
)
as
fout
:
# one example in batch
for
i
,
batch
in
enumerate
(
self
.
align_loader
):
...
...
@@ -613,6 +618,11 @@ class U2Tester(U2Trainer):
except
KeyboardInterrupt
:
sys
.
exit
(
-
1
)
def
setup_dict
(
self
):
# load dictionary for debug log
self
.
args
.
char_list
=
load_dict
(
self
.
args
.
dict_path
,
"maskctc"
in
self
.
args
.
model_name
)
def
setup
(
self
):
"""Setup the experiment.
"""
...
...
@@ -624,6 +634,8 @@ class U2Tester(U2Trainer):
self
.
setup_dataloader
()
self
.
setup_model
()
self
.
setup_dict
()
self
.
iteration
=
0
self
.
epoch
=
0
...
...
deepspeech/exps/u2_st/bin/export.py
浏览文件 @
561d5cf0
...
...
@@ -30,6 +30,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
# save jit model to
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2_st/bin/test.py
浏览文件 @
561d5cf0
...
...
@@ -34,6 +34,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
...
...
deepspeech/frontend/featurizer/__init__.py
浏览文件 @
561d5cf0
...
...
@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.audio_featurizer
import
AudioFeaturizer
#noqa: F401
from
.speech_featurizer
import
SpeechFeaturizer
from
.text_featurizer
import
TextFeaturizer
deepspeech/frontend/featurizer/audio_featurizer.py
浏览文件 @
561d5cf0
...
...
@@ -18,7 +18,7 @@ from python_speech_features import logfbank
from
python_speech_features
import
mfcc
class
AudioFeaturizer
(
object
):
class
AudioFeaturizer
():
"""Audio featurizer, for extracting features from audio contents of
AudioSegment or SpeechSegment.
...
...
deepspeech/frontend/featurizer/speech_featurizer.py
浏览文件 @
561d5cf0
...
...
@@ -16,7 +16,7 @@ from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
from
deepspeech.frontend.featurizer.text_featurizer
import
TextFeaturizer
class
SpeechFeaturizer
(
object
):
class
SpeechFeaturizer
():
"""Speech featurizer, for extracting features from both audio and transcript
contents of SpeechSegment.
...
...
deepspeech/frontend/featurizer/text_featurizer.py
浏览文件 @
561d5cf0
...
...
@@ -14,12 +14,19 @@
"""Contains the text featurizer class."""
import
sentencepiece
as
spm
from
deepspeech.frontend.utility
import
EOS
from
deepspeech.frontend.utility
import
UNK
from
..utility
import
EOS
from
..utility
import
load_dict
from
..utility
import
UNK
__all__
=
[
"TextFeaturizer"
]
class
TextFeaturizer
(
object
):
def
__init__
(
self
,
unit_type
,
vocab_filepath
,
spm_model_prefix
=
None
):
class
TextFeaturizer
():
def
__init__
(
self
,
unit_type
,
vocab_filepath
,
spm_model_prefix
=
None
,
maskctc
=
False
):
"""Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
...
...
@@ -34,11 +41,12 @@ class TextFeaturizer(object):
assert
unit_type
in
(
'char'
,
'spm'
,
'word'
)
self
.
unit_type
=
unit_type
self
.
unk
=
UNK
self
.
maskctc
=
maskctc
if
vocab_filepath
:
self
.
_vocab_dict
,
self
.
_id2token
,
self
.
_vocab_list
=
self
.
_load_vocabulary_from_file
(
vocab_filepath
)
self
.
unk_id
=
self
.
_vocab_list
.
index
(
self
.
unk
)
self
.
eos_id
=
self
.
_vocab_list
.
index
(
EOS
)
self
.
vocab_dict
,
self
.
_id2token
,
self
.
vocab_list
,
self
.
unk_id
,
self
.
eos_id
=
self
.
_load_vocabulary_from_file
(
vocab_filepath
,
maskctc
)
self
.
vocab_size
=
len
(
self
.
vocab_list
)
if
unit_type
==
'spm'
:
spm_model
=
spm_model_prefix
+
'.model'
...
...
@@ -67,7 +75,7 @@ class TextFeaturizer(object):
"""Convert text string to a list of token indices.
Args:
text (str): Text
to process
.
text (str): Text.
Returns:
List[int]: List of token indices.
...
...
@@ -75,8 +83,8 @@ class TextFeaturizer(object):
tokens
=
self
.
tokenize
(
text
)
ids
=
[]
for
token
in
tokens
:
token
=
token
if
token
in
self
.
_
vocab_dict
else
self
.
unk
ids
.
append
(
self
.
_
vocab_dict
[
token
])
token
=
token
if
token
in
self
.
vocab_dict
else
self
.
unk
ids
.
append
(
self
.
vocab_dict
[
token
])
return
ids
def
defeaturize
(
self
,
idxs
):
...
...
@@ -87,7 +95,7 @@ class TextFeaturizer(object):
idxs (List[int]): List of token indices.
Returns:
str: Text
to process
.
str: Text.
"""
tokens
=
[]
for
idx
in
idxs
:
...
...
@@ -97,33 +105,6 @@ class TextFeaturizer(object):
text
=
self
.
detokenize
(
tokens
)
return
text
@
property
def
vocab_size
(
self
):
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return
len
(
self
.
_vocab_list
)
@
property
def
vocab_list
(
self
):
"""Return the vocabulary in list.
Returns:
List[str]: tokens.
"""
return
self
.
_vocab_list
@
property
def
vocab_dict
(
self
):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]: token str -> int
"""
return
self
.
_vocab_dict
def
char_tokenize
(
self
,
text
):
"""Character tokenizer.
...
...
@@ -206,14 +187,16 @@ class TextFeaturizer(object):
return
decode
(
tokens
)
def
_load_vocabulary_from_file
(
self
,
vocab_filepath
):
def
_load_vocabulary_from_file
(
self
,
vocab_filepath
:
str
,
maskctc
:
bool
):
"""Load vocabulary from file."""
vocab_lines
=
[]
with
open
(
vocab_filepath
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
vocab_lines
.
extend
(
file
.
readlines
())
vocab_list
=
[
line
[:
-
1
]
for
line
in
vocab_lines
]
vocab_list
=
load_dict
(
vocab_filepath
,
maskctc
)
assert
vocab_list
is
not
None
id2token
=
dict
(
[(
idx
,
token
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
token2id
=
dict
(
[(
token
,
idx
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
return
token2id
,
id2token
,
vocab_list
unk_id
=
vocab_list
.
index
(
UNK
)
eos_id
=
vocab_list
.
index
(
EOS
)
return
token2id
,
id2token
,
vocab_list
,
unk_id
,
eos_id
deepspeech/frontend/utility.py
浏览文件 @
561d5cf0
...
...
@@ -15,6 +15,9 @@
import
codecs
import
json
import
math
from
typing
import
List
from
typing
import
Optional
from
typing
import
Text
import
numpy
as
np
...
...
@@ -23,16 +26,35 @@ from deepspeech.utils.log import Log
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
"load_
cmvn"
,
"read_manifest"
,
"rms_to_db"
,
"rms_to_dbfs"
,
"max
_dbfs"
,
"m
ean_dbfs"
,
"gain_db_to_ratio"
,
"normalize_audio"
,
"SOS"
,
"EOS"
,
"UNK
"
,
"
BLANK
"
"load_
dict"
,
"load_cmvn"
,
"read_manifest"
,
"rms_to_db"
,
"rms_to
_dbfs"
,
"m
ax_dbfs"
,
"mean_dbfs"
,
"gain_db_to_ratio"
,
"normalize_audio"
,
"SOS
"
,
"
EOS"
,
"UNK"
,
"BLANK"
,
"MASKCTC
"
]
IGNORE_ID
=
-
1
SOS
=
"<sos/eos>"
# `sos` and `eos` using same token
SOS
=
"<eos>"
EOS
=
SOS
UNK
=
"<unk>"
BLANK
=
"<blank>"
MASKCTC
=
"<mask>"
def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
    """Load a token dictionary file into a list of tokens.

    Each line of the file contributes one token: the text before the first
    space on that line (so both ``token`` and ``token id`` layouts work).
    The special tokens are then guaranteed to be present: ``BLANK`` is
    inserted at index 0 and ``EOS`` is appended when they are missing.

    Args:
        dict_path (Optional[Text]): Path of the dictionary file, or None.
        maskctc (bool, optional): When True, also append ``MASKCTC`` if it
            is missing. Defaults to False.

    Returns:
        Optional[List[Text]]: The token list, or None when ``dict_path`` is
        None.
    """
    if dict_path is None:
        return None

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default.
    with open(dict_path, "r", encoding="utf-8") as f:
        dictionary = f.readlines()
    # Drop the line terminator before splitting; otherwise a line without a
    # space yields "token\n", which never matches BLANK/EOS below.
    char_list = [entry.rstrip("\n").split(" ")[0] for entry in dictionary]
    if BLANK not in char_list:
        char_list.insert(0, BLANK)
    if EOS not in char_list:
        char_list.append(EOS)
    # for non-autoregressive maskctc model
    if maskctc and MASKCTC not in char_list:
        char_list.append(MASKCTC)
    return char_list
def
read_manifest
(
...
...
@@ -47,12 +69,20 @@ def read_manifest(
Args:
manifest_path ([type]): Manifest file to load and parse.
max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
max_input_len ([type], optional): maximum output seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to float('inf').
min_input_len (float, optional): minimum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to 0.0.
max_output_len (float, optional): maximum input seq length,
in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length,
in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional):
maximum output seq length/output seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional):
minimum output seq length/output seq length ratio. Defaults to 0.05.
Raises:
IOError: If failed to parse the manifest.
...
...
deepspeech/training/cli.py
浏览文件 @
561d5cf0
...
...
@@ -47,18 +47,11 @@ def default_argument_parser():
# data and output
parser
.
add_argument
(
"--config"
,
metavar
=
"FILE"
,
help
=
"path of the config file to overwrite to default config with."
)
parser
.
add_argument
(
"--dump-config"
,
metavar
=
"FILE"
,
help
=
"dump config to yaml file."
)
# parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
parser
.
add_argument
(
"--output"
,
metavar
=
"OUTPUT_DIR"
,
help
=
"path to save checkpoint and logs."
)
# load from saved checkpoint
parser
.
add_argument
(
"--checkpoint_path"
,
type
=
str
,
help
=
"path of the checkpoint to load"
)
# save jit model to
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
# running
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'gpu'
,
choices
=
[
"cpu"
,
"gpu"
],
help
=
"device type to use, cpu and gpu are supported."
)
...
...
examples/aishell/s0/conf/augmentation.json
浏览文件 @
561d5cf0
...
...
@@ -33,4 +33,4 @@
},
"prob"
:
1.0
}
]
\ No newline at end of file
]
examples/librispeech/s2/conf/transformer.yaml
浏览文件 @
561d5cf0
...
...
@@ -3,17 +3,11 @@ data:
train_manifest
:
data/manifest.train
dev_manifest
:
data/manifest.dev
test_manifest
:
data/manifest.test-clean
min_input_len
:
0.5
# second
max_input_len
:
20.0
# second
min_output_len
:
0.0
# tokens
max_output_len
:
400.0
# tokens
min_output_input_ratio
:
0.05
max_output_input_ratio
:
10.0
collator
:
vocab_filepath
:
data/
vocab
.txt
vocab_filepath
:
data/
train_960_unigram5000_units
.txt
unit_type
:
'
spm'
spm_model_prefix
:
'
data/
bpe_unigram_
5000'
spm_model_prefix
:
'
data/
train_960_unigram
5000'
mean_std_filepath
:
"
"
augmentation_config
:
conf/augmentation.json
batch_size
:
64
...
...
examples/librispeech/s2/local/align.sh
浏览文件 @
561d5cf0
#!/bin/bash
if
[
$#
!=
2
]
;
then
echo
"usage:
${
0
}
config_path ckpt_path_prefix"
if
[
$#
!=
3
]
;
then
echo
"usage:
${
0
}
config_path
dict_path
ckpt_path_prefix"
exit
-1
fi
...
...
@@ -13,7 +13,8 @@ if [ ${ngpu} == 0 ];then
device
=
cpu
fi
config_path
=
$1
ckpt_prefix
=
$2
dict_path
=
$2
ckpt_prefix
=
$3
batch_size
=
1
output_dir
=
${
ckpt_prefix
}
...
...
@@ -22,11 +23,13 @@ mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3
-u
${
BIN_DIR
}
/test.py
\
--run_mode
'align'
\
--model-name
'u2_kaldi'
\
--run-mode
'align'
\
--dict-path
${
dict_path
}
\
--device
${
device
}
\
--nproc
1
\
--config
${
config_path
}
\
--result
_
file
${
output_dir
}
/
${
type
}
.align
\
--result
-
file
${
output_dir
}
/
${
type
}
.align
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.batch_size
${
batch_size
}
...
...
examples/librispeech/s2/local/export.sh
浏览文件 @
561d5cf0
...
...
@@ -18,7 +18,8 @@ if [ ${ngpu} == 0 ];then
fi
python3
-u
${
BIN_DIR
}
/test.py
\
--run_mode
'export'
\
--model-name
'u2_kaldi'
\
--run-mode
'export'
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
...
...
examples/librispeech/s2/local/test.sh
浏览文件 @
561d5cf0
#!/bin/bash
if
[
$#
!=
2
]
;
then
echo
"usage:
${
0
}
config_path ckpt_path_prefix"
if
[
$#
!=
3
]
;
then
echo
"usage:
${
0
}
config_path
dict_path
ckpt_path_prefix"
exit
-1
fi
...
...
@@ -14,7 +14,8 @@ if [ ${ngpu} == 0 ];then
fi
config_path
=
$1
ckpt_prefix
=
$2
dict_path
=
$2
ckpt_prefix
=
$3
chunk_mode
=
false
if
[[
${
config_path
}
=
~ ^.
*
chunk_.
*
yaml
$
]]
;
then
...
...
@@ -38,11 +39,13 @@ for type in attention ctc_greedy_search; do
batch_size
=
64
fi
python3
-u
${
BIN_DIR
}
/test.py
\
--run_mode
test
\
--model-name
u2_kaldi
\
--run-mode
test
\
--dict-path
${
dict_path
}
\
--device
${
device
}
\
--nproc
1
\
--config
${
config_path
}
\
--result
_
file
${
ckpt_prefix
}
.
${
type
}
.rsl
\
--result
-
file
${
ckpt_prefix
}
.
${
type
}
.rsl
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.decoding_method
${
type
}
decoding.batch_size
${
batch_size
}
...
...
@@ -56,11 +59,13 @@ for type in ctc_prefix_beam_search attention_rescoring; do
echo
"decoding
${
type
}
"
batch_size
=
1
python3
-u
${
BIN_DIR
}
/test.py
\
--run_mode
test
\
--model-name
u2_kaldi
\
--run-mode
test
\
--dict-path
${
dict_path
}
\
--device
${
device
}
\
--nproc
1
\
--config
${
config_path
}
\
--result
_
file
${
ckpt_prefix
}
.
${
type
}
.rsl
\
--result
-
file
${
ckpt_prefix
}
.
${
type
}
.rsl
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.decoding_method
${
type
}
decoding.batch_size
${
batch_size
}
...
...
examples/librispeech/s2/run.sh
浏览文件 @
561d5cf0
...
...
@@ -5,6 +5,7 @@ source path.sh
stage
=
0
stop_stage
=
100
conf_path
=
conf/transformer.yaml
dict_path
=
data/train_960_unigram5000_units.txt
avg_num
=
5
source
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
...
...
@@ -29,12 +30,12 @@ fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES
=
0 ./local/test.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
CUDA_VISIBLE_DEVICES
=
0 ./local/test.sh
${
conf_path
}
${
dict_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
fi
if
[
${
stage
}
-le
4
]
&&
[
${
stop_stage
}
-ge
4
]
;
then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES
=
0 ./local/align.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
CUDA_VISIBLE_DEVICES
=
0 ./local/align.sh
${
conf_path
}
${
dict_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
fi
if
[
${
stage
}
-le
5
]
&&
[
${
stop_stage
}
-ge
5
]
;
then
...
...
examples/tiny/s0/conf/augmentation.json
浏览文件 @
561d5cf0
...
...
@@ -29,8 +29,7 @@
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
,
"warp_mode"
:
"PIL"
"replace_with_zero"
:
true
},
"prob"
:
1.0
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录