Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
0b7e0d1e
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0b7e0d1e
编写于
12月 07, 2021
作者:
K
KP
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update tags of pretrained_models.
上级
f8204c98
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
35 addition
and
24 deletion
+35
-24
paddlespeech/cli/asr/infer.py
paddlespeech/cli/asr/infer.py
+17
-11
paddlespeech/cli/cls/infer.py
paddlespeech/cli/cls/infer.py
+18
-13
未找到文件。
paddlespeech/cli/asr/infer.py
浏览文件 @
0b7e0d1e
...
@@ -39,7 +39,11 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
...
@@ -39,7 +39,11 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__
=
[
'ASRExecutor'
]
__all__
=
[
'ASRExecutor'
]
pretrained_models
=
{
pretrained_models
=
{
"wenetspeech_zh_16k"
:
{
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"conformer_wenetspeech-zh-16k"
:
{
'url'
:
'url'
:
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz'
,
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz'
,
'md5'
:
'md5'
:
...
@@ -49,7 +53,7 @@ pretrained_models = {
...
@@ -49,7 +53,7 @@ pretrained_models = {
'ckpt_path'
:
'ckpt_path'
:
'exp/conformer/checkpoints/wenetspeech'
,
'exp/conformer/checkpoints/wenetspeech'
,
},
},
"transformer_
zh_
16k"
:
{
"transformer_
aishell-zh-
16k"
:
{
'url'
:
'url'
:
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz'
,
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz'
,
'md5'
:
'md5'
:
...
@@ -83,7 +87,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -83,7 +87,7 @@ class ASRExecutor(BaseExecutor):
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--model'
,
'--model'
,
type
=
str
,
type
=
str
,
default
=
'wenetspeech'
,
default
=
'
conformer_
wenetspeech'
,
help
=
'Choose model type of asr task.'
)
help
=
'Choose model type of asr task.'
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--lang'
,
'--lang'
,
...
@@ -143,7 +147,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -143,7 +147,7 @@ class ASRExecutor(BaseExecutor):
if
cfg_path
is
None
or
ckpt_path
is
None
:
if
cfg_path
is
None
or
ckpt_path
is
None
:
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
tag
=
model_type
+
'
_'
+
lang
+
'_
'
+
sample_rate_str
tag
=
model_type
+
'
-'
+
lang
+
'-
'
+
sample_rate_str
res_path
=
self
.
_get_pretrained_path
(
tag
)
# wenetspeech_zh
res_path
=
self
.
_get_pretrained_path
(
tag
)
# wenetspeech_zh
self
.
res_path
=
res_path
self
.
res_path
=
res_path
self
.
cfg_path
=
os
.
path
.
join
(
res_path
,
self
.
cfg_path
=
os
.
path
.
join
(
res_path
,
...
@@ -165,7 +169,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -165,7 +169,7 @@ class ASRExecutor(BaseExecutor):
self
.
config
.
decoding
.
decoding_method
=
"attention_rescoring"
self
.
config
.
decoding
.
decoding_method
=
"attention_rescoring"
with
UpdateConfig
(
self
.
config
):
with
UpdateConfig
(
self
.
config
):
if
model_type
==
"ds2_online"
or
model_type
==
"ds2_offline"
:
if
"ds2_online"
in
model_type
or
"ds2_offline"
in
model_type
:
from
paddlespeech.s2t.io.collator
import
SpeechCollator
from
paddlespeech.s2t.io.collator
import
SpeechCollator
self
.
config
.
collator
.
vocab_filepath
=
os
.
path
.
join
(
self
.
config
.
collator
.
vocab_filepath
=
os
.
path
.
join
(
res_path
,
self
.
config
.
collator
.
vocab_filepath
)
res_path
,
self
.
config
.
collator
.
vocab_filepath
)
...
@@ -178,7 +182,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -178,7 +182,7 @@ class ASRExecutor(BaseExecutor):
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
self
.
config
.
model
.
input_dim
=
self
.
collate_fn_test
.
feature_size
self
.
config
.
model
.
input_dim
=
self
.
collate_fn_test
.
feature_size
self
.
config
.
model
.
output_dim
=
text_feature
.
vocab_size
self
.
config
.
model
.
output_dim
=
text_feature
.
vocab_size
elif
model_type
==
"conformer"
or
model_type
==
"transformer"
or
model_type
==
"wenetspeech"
:
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
or
"wenetspeech"
in
model_type
:
self
.
config
.
collator
.
vocab_filepath
=
os
.
path
.
join
(
self
.
config
.
collator
.
vocab_filepath
=
os
.
path
.
join
(
res_path
,
self
.
config
.
collator
.
vocab_filepath
)
res_path
,
self
.
config
.
collator
.
vocab_filepath
)
self
.
config
.
collator
.
augmentation_config
=
os
.
path
.
join
(
self
.
config
.
collator
.
augmentation_config
=
os
.
path
.
join
(
...
@@ -196,7 +200,9 @@ class ASRExecutor(BaseExecutor):
...
@@ -196,7 +200,9 @@ class ASRExecutor(BaseExecutor):
raise
Exception
(
"wrong type"
)
raise
Exception
(
"wrong type"
)
# Enter the path of model root
# Enter the path of model root
model_class
=
dynamic_import
(
model_type
,
model_alias
)
model_name
=
''
.
join
(
model_type
.
split
(
'_'
)[:
-
1
])
# model_type: {model_name}_{dataset}
model_class
=
dynamic_import
(
model_name
,
model_alias
)
model_conf
=
self
.
config
.
model
model_conf
=
self
.
config
.
model
logger
.
info
(
model_conf
)
logger
.
info
(
model_conf
)
model
=
model_class
.
from_config
(
model_conf
)
model
=
model_class
.
from_config
(
model_conf
)
...
@@ -217,7 +223,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -217,7 +223,7 @@ class ASRExecutor(BaseExecutor):
logger
.
info
(
"Preprocess audio_file:"
+
audio_file
)
logger
.
info
(
"Preprocess audio_file:"
+
audio_file
)
# Get the object for feature extraction
# Get the object for feature extraction
if
model_type
==
"ds2_online"
or
model_type
==
"ds2_offline"
:
if
"ds2_online"
in
model_type
or
"ds2_offline"
in
model_type
:
audio
,
_
=
self
.
collate_fn_test
.
process_utterance
(
audio
,
_
=
self
.
collate_fn_test
.
process_utterance
(
audio_file
=
audio_file
,
transcript
=
" "
)
audio_file
=
audio_file
,
transcript
=
" "
)
audio_len
=
audio
.
shape
[
0
]
audio_len
=
audio
.
shape
[
0
]
...
@@ -229,7 +235,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -229,7 +235,7 @@ class ASRExecutor(BaseExecutor):
self
.
_inputs
[
"audio_len"
]
=
audio_len
self
.
_inputs
[
"audio_len"
]
=
audio_len
logger
.
info
(
f
"audio feat shape:
{
audio
.
shape
}
"
)
logger
.
info
(
f
"audio feat shape:
{
audio
.
shape
}
"
)
elif
model_type
==
"conformer"
or
model_type
==
"transformer"
or
model_type
==
"wenetspeech"
:
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
or
"wenetspeech"
in
model_type
:
logger
.
info
(
"get the preprocess conf"
)
logger
.
info
(
"get the preprocess conf"
)
preprocess_conf_file
=
self
.
config
.
collator
.
augmentation_config
preprocess_conf_file
=
self
.
config
.
collator
.
augmentation_config
# redirect the cmvn path
# redirect the cmvn path
...
@@ -293,7 +299,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -293,7 +299,7 @@ class ASRExecutor(BaseExecutor):
cfg
=
self
.
config
.
decoding
cfg
=
self
.
config
.
decoding
audio
=
self
.
_inputs
[
"audio"
]
audio
=
self
.
_inputs
[
"audio"
]
audio_len
=
self
.
_inputs
[
"audio_len"
]
audio_len
=
self
.
_inputs
[
"audio_len"
]
if
model_type
==
"ds2_online"
or
model_type
==
"ds2_offline"
:
if
"ds2_online"
in
model_type
or
"ds2_offline"
in
model_type
:
result_transcripts
=
self
.
model
.
decode
(
result_transcripts
=
self
.
model
.
decode
(
audio
,
audio
,
audio_len
,
audio_len
,
...
@@ -308,7 +314,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -308,7 +314,7 @@ class ASRExecutor(BaseExecutor):
num_processes
=
cfg
.
num_proc_bsearch
)
num_processes
=
cfg
.
num_proc_bsearch
)
self
.
_outputs
[
"result"
]
=
result_transcripts
[
0
]
self
.
_outputs
[
"result"
]
=
result_transcripts
[
0
]
elif
model_type
==
"conformer"
or
model_type
==
"transformer"
or
model_type
==
"wenetspeech"
:
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
or
"wenetspeech"
in
model_type
:
result_transcripts
=
self
.
model
.
decode
(
result_transcripts
=
self
.
model
.
decode
(
audio
,
audio
,
audio_len
,
audio_len
,
...
...
paddlespeech/cli/cls/infer.py
浏览文件 @
0b7e0d1e
...
@@ -33,21 +33,25 @@ from paddlespeech.s2t.utils.dynamic_import import dynamic_import
...
@@ -33,21 +33,25 @@ from paddlespeech.s2t.utils.dynamic_import import dynamic_import
__all__
=
[
'CLSExecutor'
]
__all__
=
[
'CLSExecutor'
]
pretrained_models
=
{
pretrained_models
=
{
"panns_cnn6"
:
{
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"panns_cnn6-32k"
:
{
'url'
:
'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz'
,
'url'
:
'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz'
,
'md5'
:
'
051b30c56bcb9a3dd67bc205cc12ffd2
'
,
'md5'
:
'
4cf09194a95df024fd12f84712cf0f9c
'
,
'cfg_path'
:
'panns.yaml'
,
'cfg_path'
:
'panns.yaml'
,
'ckpt_path'
:
'cnn6.pdparams'
,
'ckpt_path'
:
'cnn6.pdparams'
,
'label_file'
:
'audioset_labels.txt'
,
'label_file'
:
'audioset_labels.txt'
,
},
},
"panns_cnn10"
:
{
"panns_cnn10
-32k
"
:
{
'url'
:
'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz'
,
'url'
:
'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz'
,
'md5'
:
'
97c6f25587685379b1ebcd4c1f624927
'
,
'md5'
:
'
cb8427b22176cc2116367d14847f5413
'
,
'cfg_path'
:
'panns.yaml'
,
'cfg_path'
:
'panns.yaml'
,
'ckpt_path'
:
'cnn10.pdparams'
,
'ckpt_path'
:
'cnn10.pdparams'
,
'label_file'
:
'audioset_labels.txt'
,
'label_file'
:
'audioset_labels.txt'
,
},
},
"panns_cnn14"
:
{
"panns_cnn14
-32k
"
:
{
'url'
:
'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz'
,
'url'
:
'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz'
,
'md5'
:
'e3b9b5614a1595001161d0ab95edee97'
,
'md5'
:
'e3b9b5614a1595001161d0ab95edee97'
,
'cfg_path'
:
'panns.yaml'
,
'cfg_path'
:
'panns.yaml'
,
...
@@ -76,7 +80,7 @@ class CLSExecutor(BaseExecutor):
...
@@ -76,7 +80,7 @@ class CLSExecutor(BaseExecutor):
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--model'
,
'--model'
,
type
=
str
,
type
=
str
,
default
=
'panns_cnn1
4
'
,
default
=
'panns_cnn1
0
'
,
help
=
'Choose model type of cls task.'
)
help
=
'Choose model type of cls task.'
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--config'
,
'--config'
,
...
@@ -133,13 +137,14 @@ class CLSExecutor(BaseExecutor):
...
@@ -133,13 +137,14 @@ class CLSExecutor(BaseExecutor):
return
return
if
label_file
is
None
or
ckpt_path
is
None
:
if
label_file
is
None
or
ckpt_path
is
None
:
self
.
res_path
=
self
.
_get_pretrained_path
(
model_type
)
# panns_cnn14
tag
=
model_type
+
'-'
+
'32k'
# panns_cnn14-32k
self
.
cfg_path
=
os
.
path
.
join
(
self
.
res_path
=
self
.
_get_pretrained_path
(
tag
)
self
.
res_path
,
pretrained_models
[
model_type
][
'cfg_path'
])
self
.
cfg_path
=
os
.
path
.
join
(
self
.
res_path
,
self
.
label_file
=
os
.
path
.
join
(
pretrained_models
[
tag
][
'cfg_path'
])
self
.
res_path
,
pretrained_models
[
model_type
][
'label_file'
])
self
.
label_file
=
os
.
path
.
join
(
self
.
res_path
,
self
.
ckpt_path
=
os
.
path
.
join
(
pretrained_models
[
tag
][
'label_file'
])
self
.
res_path
,
pretrained_models
[
model_type
][
'ckpt_path'
])
self
.
ckpt_path
=
os
.
path
.
join
(
self
.
res_path
,
pretrained_models
[
tag
][
'ckpt_path'
])
else
:
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
label_file
=
os
.
path
.
abspath
(
label_file
)
self
.
label_file
=
os
.
path
.
abspath
(
label_file
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录