Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
5c9e4caa
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
5c9e4caa
编写于
1月 10, 2022
作者:
H
huangyuxin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add en and decode_method for cli/asr, test=asr
上级
50ceca9d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
225 addition
and
53 deletion
+225
-53
paddlespeech/cli/asr/infer.py
paddlespeech/cli/asr/infer.py
+42
-49
paddlespeech/s2t/frontend/normalizer.py
paddlespeech/s2t/frontend/normalizer.py
+9
-4
utils/generate_infer_yaml.py
utils/generate_infer_yaml.py
+174
-0
未找到文件。
paddlespeech/cli/asr/infer.py
浏览文件 @
5c9e4caa
...
...
@@ -46,19 +46,29 @@ pretrained_models = {
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"conformer_wenetspeech-zh-16k"
:
{
'url'
:
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/
conformer
.model.tar.gz'
,
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/
asr1_conformer_wenetspeech_ckpt_0.1.1
.model.tar.gz'
,
'md5'
:
'
54e7a558a6e020c2f5fb224874943f97
'
,
'
b9afd8285ff5b2596bf96afab656b02f
'
,
'cfg_path'
:
'conf/conformer.yaml'
,
'conf/conformer
_infer
.yaml'
,
'ckpt_path'
:
'exp/conformer/checkpoints/wenetspeech'
,
},
"transformer_librispeech-en-16k"
:
{
'url'
:
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz'
,
'md5'
:
'c95b9997f5f81478b32879a38532913d'
,
'cfg_path'
:
'conf/transformer_infer.yaml'
,
'ckpt_path'
:
'exp/transformer/checkpoints/avg_10'
,
},
}
model_alias
=
{
"d
s2_
offline"
:
"paddlespeech.s2t.models.ds2:DeepSpeech2Model"
,
"d
s2_
online"
:
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline"
,
"d
eepspeech2
offline"
:
"paddlespeech.s2t.models.ds2:DeepSpeech2Model"
,
"d
eepspeech2
online"
:
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline"
,
"conformer"
:
"paddlespeech.s2t.models.u2:U2Model"
,
"transformer"
:
"paddlespeech.s2t.models.u2:U2Model"
,
"wenetspeech"
:
"paddlespeech.s2t.models.u2:U2Model"
,
...
...
@@ -85,7 +95,7 @@ class ASRExecutor(BaseExecutor):
'--lang'
,
type
=
str
,
default
=
'zh'
,
help
=
'Choose model language. zh or en'
)
help
=
'Choose model language. zh or en
, zh:[conformer_wenetspeech-zh-16k], en:[transformer_librispeech-en-16k]
'
)
self
.
parser
.
add_argument
(
"--sample_rate"
,
type
=
int
,
...
...
@@ -97,6 +107,12 @@ class ASRExecutor(BaseExecutor):
type
=
str
,
default
=
None
,
help
=
'Config of asr task. Use deault config when it is None.'
)
self
.
parser
.
add_argument
(
'--decode_method'
,
type
=
str
,
default
=
'attention_rescoring'
,
choices
=
[
'ctc_greedy_search'
,
'ctc_prefix_beam_search'
,
'attention'
,
'attention_rescoring'
],
help
=
'only support transformer and conformer model'
)
self
.
parser
.
add_argument
(
'--ckpt_path'
,
type
=
str
,
...
...
@@ -136,6 +152,7 @@ class ASRExecutor(BaseExecutor):
lang
:
str
=
'zh'
,
sample_rate
:
int
=
16000
,
cfg_path
:
Optional
[
os
.
PathLike
]
=
None
,
decode_method
:
str
=
'attention_rescoring'
,
ckpt_path
:
Optional
[
os
.
PathLike
]
=
None
):
"""
Init model and other resources from a specific path.
...
...
@@ -165,45 +182,30 @@ class ASRExecutor(BaseExecutor):
#Init body.
self
.
config
=
CfgNode
(
new_allowed
=
True
)
self
.
config
.
merge_from_file
(
self
.
cfg_path
)
self
.
config
.
decoding
.
decoding_method
=
"attention_rescoring"
with
UpdateConfig
(
self
.
config
):
if
"d
s2_online"
in
model_type
or
"ds2_
offline"
in
model_type
:
if
"d
eepspeech2online"
in
model_type
or
"deepspeech2
offline"
in
model_type
:
from
paddlespeech.s2t.io.collator
import
SpeechCollator
self
.
config
.
collator
.
vocab_filepath
=
os
.
path
.
join
(
res_path
,
self
.
config
.
collator
.
vocab_filepath
)
self
.
config
.
collator
.
mean_std_filepath
=
os
.
path
.
join
(
res_path
,
self
.
config
.
collator
.
cmvn_path
)
self
.
vocab
=
self
.
config
.
vocab_filepath
self
.
config
.
decode
.
lang_model_path
=
os
.
path
.
join
(
res_path
,
self
.
config
.
decode
.
lang_model_path
)
self
.
collate_fn_test
=
SpeechCollator
.
from_config
(
self
.
config
)
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
self
.
config
.
model
.
input_dim
=
self
.
collate_fn_test
.
feature_size
self
.
config
.
model
.
output_dim
=
self
.
text_feature
.
vocab_size
unit_type
=
self
.
config
.
unit_type
,
vocab
=
self
.
vocab
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
or
"wenetspeech"
in
model_type
:
self
.
config
.
collator
.
vocab_filepath
=
os
.
path
.
join
(
res_path
,
self
.
config
.
collator
.
vocab_filepath
)
self
.
config
.
collator
.
augmentation_config
=
os
.
path
.
join
(
res_path
,
self
.
config
.
collator
.
augmentation_config
)
self
.
config
.
collator
.
spm_model_prefix
=
os
.
path
.
join
(
res_path
,
self
.
config
.
collator
.
spm_model_prefix
)
self
.
config
.
spm_model_prefix
=
os
.
path
.
join
(
self
.
res_path
,
self
.
config
.
spm_model_prefix
)
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
self
.
config
.
model
.
input_dim
=
self
.
config
.
collator
.
feat_dim
self
.
config
.
model
.
output_dim
=
self
.
text_feature
.
vocab_size
unit_type
=
self
.
config
.
unit_type
,
vocab
=
self
.
config
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
spm_model_prefix
)
self
.
config
.
decode
.
decoding_method
=
decode_method
else
:
raise
Exception
(
"wrong type"
)
# Enter the path of model root
model_name
=
model_type
[:
model_type
.
rindex
(
'_'
)]
# model_type: {model_name}_{dataset}
model_class
=
dynamic_import
(
model_name
,
model_alias
)
model_conf
=
self
.
config
.
model
logger
.
info
(
model_conf
)
model_conf
=
self
.
config
model
=
model_class
.
from_config
(
model_conf
)
self
.
model
=
model
self
.
model
.
eval
()
...
...
@@ -222,7 +224,7 @@ class ASRExecutor(BaseExecutor):
logger
.
info
(
"Preprocess audio_file:"
+
audio_file
)
# Get the object for feature extraction
if
"d
s2_online"
in
model_type
or
"ds2_
offline"
in
model_type
:
if
"d
eepspeech2online"
in
model_type
or
"deepspeech2
offline"
in
model_type
:
audio
,
_
=
self
.
collate_fn_test
.
process_utterance
(
audio_file
=
audio_file
,
transcript
=
" "
)
audio_len
=
audio
.
shape
[
0
]
...
...
@@ -236,18 +238,7 @@ class ASRExecutor(BaseExecutor):
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
or
"wenetspeech"
in
model_type
:
logger
.
info
(
"get the preprocess conf"
)
preprocess_conf_file
=
self
.
config
.
collator
.
augmentation_config
# redirect the cmvn path
with
io
.
open
(
preprocess_conf_file
,
encoding
=
"utf-8"
)
as
f
:
preprocess_conf
=
yaml
.
safe_load
(
f
)
for
idx
,
process
in
enumerate
(
preprocess_conf
[
"process"
]):
if
process
[
'type'
]
==
"cmvn_json"
:
preprocess_conf
[
"process"
][
idx
][
"cmvn_path"
]
=
os
.
path
.
join
(
self
.
res_path
,
preprocess_conf
[
"process"
][
idx
][
"cmvn_path"
])
break
logger
.
info
(
preprocess_conf
)
preprocess_conf
=
self
.
config
.
preprocess_config
preprocess_args
=
{
"train"
:
False
}
preprocessing
=
Transformation
(
preprocess_conf
)
logger
.
info
(
"read the audio file"
)
...
...
@@ -289,10 +280,10 @@ class ASRExecutor(BaseExecutor):
Model inference and result stored in self.output.
"""
cfg
=
self
.
config
.
decod
ing
cfg
=
self
.
config
.
decod
e
audio
=
self
.
_inputs
[
"audio"
]
audio_len
=
self
.
_inputs
[
"audio_len"
]
if
"d
s2_online"
in
model_type
or
"ds2_
offline"
in
model_type
:
if
"d
eepspeech2online"
in
model_type
or
"deepspeech2
offline"
in
model_type
:
result_transcripts
=
self
.
model
.
decode
(
audio
,
audio_len
,
...
...
@@ -414,12 +405,13 @@ class ASRExecutor(BaseExecutor):
config
=
parser_args
.
config
ckpt_path
=
parser_args
.
ckpt_path
audio_file
=
parser_args
.
input
decode_method
=
parser_args
.
decode_method
force_yes
=
parser_args
.
yes
device
=
parser_args
.
device
try
:
res
=
self
(
audio_file
,
model
,
lang
,
sample_rate
,
config
,
ckpt_path
,
force_yes
,
device
)
decode_method
,
force_yes
,
device
)
logger
.
info
(
'ASR Result: {}'
.
format
(
res
))
return
True
except
Exception
as
e
:
...
...
@@ -434,6 +426,7 @@ class ASRExecutor(BaseExecutor):
sample_rate
:
int
=
16000
,
config
:
os
.
PathLike
=
None
,
ckpt_path
:
os
.
PathLike
=
None
,
decode_method
:
str
=
'attention_rescoring'
,
force_yes
:
bool
=
False
,
device
=
paddle
.
get_device
()):
"""
...
...
@@ -442,7 +435,7 @@ class ASRExecutor(BaseExecutor):
audio_file
=
os
.
path
.
abspath
(
audio_file
)
self
.
_check
(
audio_file
,
sample_rate
,
force_yes
)
paddle
.
set_device
(
device
)
self
.
_init_from_path
(
model
,
lang
,
sample_rate
,
config
,
ckpt_path
)
self
.
_init_from_path
(
model
,
lang
,
sample_rate
,
config
,
decode_method
,
ckpt_path
)
self
.
preprocess
(
model
,
audio_file
)
self
.
infer
(
model
)
res
=
self
.
postprocess
()
# Retrieve result of asr.
...
...
paddlespeech/s2t/frontend/normalizer.py
浏览文件 @
5c9e4caa
...
...
@@ -117,7 +117,8 @@ class FeatureNormalizer(object):
self
.
_compute_mean_std
(
manifest_path
,
featurize_func
,
num_samples
,
num_workers
)
else
:
self
.
_read_mean_std_from_file
(
mean_std_filepath
)
mean_std
=
mean_std_filepath
self
.
_read_mean_std_from_file
(
mean_std
)
def
apply
(
self
,
features
):
"""Normalize features to be of zero mean and unit stddev.
...
...
@@ -131,10 +132,14 @@ class FeatureNormalizer(object):
"""
return
(
features
-
self
.
_mean
)
*
self
.
_istd
def
_read_mean_std_from_file
(
self
,
filepath
,
eps
=
1e-20
):
def
_read_mean_std_from_file
(
self
,
mean_std
,
eps
=
1e-20
):
"""Load mean and std from file."""
filetype
=
filepath
.
split
(
"."
)[
-
1
]
mean
,
istd
=
load_cmvn
(
filepath
,
filetype
=
filetype
)
if
isinstance
(
mean_std
,
list
):
mean
=
mean_std
[
0
][
'cmvn_stats'
][
'mean'
]
istd
=
mean_std
[
0
][
'cmvn_stats'
][
'istd'
]
else
:
filetype
=
mean_std
.
split
(
"."
)[
-
1
]
mean
,
istd
=
load_cmvn
(
mean_std
,
filetype
=
filetype
)
self
.
_mean
=
np
.
expand_dims
(
mean
,
axis
=
0
)
self
.
_istd
=
np
.
expand_dims
(
istd
,
axis
=
0
)
...
...
utils/generate_infer_yaml.py
0 → 100644
浏览文件 @
5c9e4caa
#!/usr/bin/env python3
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
'''
Merge training configs into a single inference config.
'''
import
yaml
import
json
import
os
import
argparse
import
math
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.frontend.utility
import
load_dict
from
contextlib
import
redirect_stdout
def save(save_path, config):
    """Write the text returned by ``config.dump()`` to ``save_path``.

    Args:
        save_path: Destination file path; the file is created or overwritten.
        config: Object exposing a ``dump()`` method that returns the text to
            store (``merge_configs`` passes the config returned by ``load``).
    """
    with open(save_path, 'w') as out_file:
        # ``print`` adds one trailing newline after the dumped text.
        print(config.dump(), file=out_file)
def load(save_path):
    """Read a config file into a freshly created ``CfgNode``.

    Args:
        save_path: Path of the config file to read.

    Returns:
        A ``CfgNode`` built with ``new_allowed=True`` and then merged with the
        contents of ``save_path``.
    """
    loaded = CfgNode(new_allowed=True)
    loaded.merge_from_file(save_path)
    return loaded
def load_json(json_path):
    """Parse the JSON document stored at ``json_path``.

    The file is decoded as UTF-8 (the standard JSON text encoding) instead of
    the platform's locale encoding, so the result does not depend on the
    machine the script runs on.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content as returned by ``json.load``.
    """
    with open(json_path, encoding="utf-8") as json_file:
        return json.load(json_file)
def remove_config_part(config, key_list):
    """Remove the entry addressed by ``key_list`` from a nested mapping.

    All keys except the last are used to descend into ``config``; the last
    key is then popped from the mapping reached that way. ``config`` is
    modified in place and nothing is returned.

    Args:
        config: Nested mapping to modify.
        key_list: Sequence of keys forming the path to the entry. An empty
            sequence is a no-op.

    Raises:
        KeyError: If any key along the path is missing (for dict-like nodes).
    """
    if not key_list:
        return
    *parent_keys, leaf_key = key_list
    node = config
    for key in parent_keys:
        node = node[key]
    node.pop(leaf_key)
def load_cmvn_from_json(cmvn_stats):
    """Turn accumulated CMVN statistics into mean and inverse-std vectors.

    For every dimension ``i`` the function computes::

        mean[i] = mean_stat[i] / frame_num
        var[i]  = var_stat[i] / frame_num - mean[i] ** 2   (floored at 1e-20)
        istd[i] = 1 / sqrt(var[i])

    The input mapping and its lists are left untouched; new lists are
    returned, so the function can safely be called more than once on the
    same statistics.

    Args:
        cmvn_stats: Mapping with the keys ``'mean_stat'`` and ``'var_stat'``
            (equal-length sequences of numbers) and ``'frame_num'`` (the
            divisor applied to both).

    Returns:
        A dict ``{"mean": [...], "istd": [...]}`` with one value per dimension.

    Raises:
        ValueError: If ``mean_stat`` and ``var_stat`` differ in length.
    """
    mean_stat = cmvn_stats['mean_stat']
    var_stat = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    if len(mean_stat) != len(var_stat):
        raise ValueError(
            "mean_stat and var_stat must have the same length, got "
            "{} and {}".format(len(mean_stat), len(var_stat)))

    means = []
    istd = []
    for stat_sum, stat_sq_sum in zip(mean_stat, var_stat):
        mean = stat_sum / count
        variance = stat_sq_sum / count - mean * mean
        # Floor the variance so the square root and division stay finite.
        if variance < 1.0e-20:
            variance = 1.0e-20
        means.append(mean)
        istd.append(1.0 / math.sqrt(variance))
    return {"mean": means, "istd": istd}
def merge_configs(
        conf_path="conf/conformer.yaml",
        preprocess_path="conf/preprocess.yaml",
        decode_path="conf/tuning/decode.yaml",
        vocab_path="data/vocab.txt",
        cmvn_path="data/mean_std.json",
        save_path="conf/conformer_infer.yaml", ):
    """Merge several training-time files into one inference config file.

    The model config is loaded from ``conf_path`` and extended with:

    * the decode config from ``decode_path`` (stored as ``config.decode``),
    * the vocabulary loaded from ``vocab_path`` (stored inline as
      ``config.vocab_filepath``; its length becomes ``config.output_dim``),
    * the CMVN statistics from ``cmvn_path``.

    If ``preprocess_path`` exists, the preprocess config is embedded as
    ``config.preprocess_config`` and the raw CMVN statistics replace the
    ``cmvn_path`` entry of its first ``cmvn_json`` step. Otherwise the
    statistics are converted with ``load_cmvn_from_json`` and stored as
    ``config.mean_std_filepath``, and ``config.augmentation_config`` is set
    to an empty string.

    A fixed list of keys (which list depends on whether ``preprocess_path``
    exists) is then removed from the top level of the config, and the result
    is written to ``save_path``.

    Args:
        conf_path: Path of the model/training config.
        preprocess_path: Path of the preprocess config; may not exist.
        decode_path: Path of the decode config.
        vocab_path: Path of the vocabulary file.
        cmvn_path: Path of the JSON file with CMVN statistics.
        save_path: Path the merged config is written to.
    """
    # Load the configs
    config = load(conf_path)
    decode_config = load(decode_path)
    vocab_list = load_dict(vocab_path)
    cmvn_stats = load_json(cmvn_path)

    # Check once so the merge step and the key-removal step below always
    # agree on which branch applies.
    has_preprocess = os.path.exists(preprocess_path)

    if has_preprocess:
        preprocess_config = load(preprocess_path)
        # Inline the statistics into the first cmvn_json step only.
        for process in preprocess_config["process"]:
            if process['type'] == "cmvn_json":
                process["cmvn_path"] = cmvn_stats
                break
        config.preprocess_config = preprocess_config
    else:
        cmvn_stats = load_cmvn_from_json(cmvn_stats)
        config.mean_std_filepath = [{"cmvn_stats": cmvn_stats}]
        config.augmentation_config = ''

    # Update the config
    config.vocab_filepath = vocab_list
    config.input_dim = config.feat_dim
    config.output_dim = len(config.vocab_filepath)
    config.decode = decode_config

    # Remove some parts of the config
    if has_preprocess:
        remove_list = [
            "train_manifest",
            "dev_manifest",
            "test_manifest",
            "n_epoch",
            "accum_grad",
            "global_grad_clip",
            "optim",
            "optim_conf",
            "scheduler",
            "scheduler_conf",
            "log_interval",
            "checkpoint",
            "shuffle_method",
            "weight_decay",
            "ctc_grad_norm_type",
            "minibatches",
            "batch_bins",
            "batch_count",
            "batch_frames_in",
            "batch_frames_inout",
            "batch_frames_out",
            "sortagrad",
            "feat_dim",
            "stride_ms",
            "window_ms",
            "batch_size",
            "maxlen_in",
            "maxlen_out",
        ]
    else:
        remove_list = [
            "train_manifest",
            "dev_manifest",
            "test_manifest",
            "n_epoch",
            "accum_grad",
            "global_grad_clip",
            "log_interval",
            "checkpoint",
            "lr",
            "lr_decay",
            "batch_size",
            "shuffle_method",
            "weight_decay",
            "sortagrad",
            "num_workers",
        ]

    for item in remove_list:
        try:
            remove_config_part(config, [item])
        except KeyError:
            # A key that is absent from this config is reported, not fatal.
            print(item + " " + "can not be removed")

    # Save the config
    save(save_path, config)
if __name__ == "__main__":
    # Command-line entry point: every option is a file path that is handed
    # to ``merge_configs`` under the matching keyword argument.
    parser = argparse.ArgumentParser(prog='Config merge', add_help=True)
    # Model/training config to start from.
    parser.add_argument(
        '--cfg_pth',
        type=str,
        default='conf/transformer.yaml',
        help='origin config file')
    # Preprocess config; ``merge_configs`` handles the case where it is absent.
    parser.add_argument(
        '--pre_pth', type=str, default="conf/preprocess.yaml", help='')
    # Decode config. The default matches the ``decode_path`` default of
    # ``merge_configs`` (the previous default misspelled the directory).
    parser.add_argument(
        '--dcd_pth', type=str, default="conf/tuning/decode.yaml", help='')
    # Vocabulary file.
    parser.add_argument(
        '--vb_pth', type=str, default="data/lang_char/vocab.txt", help='')
    # JSON file with the CMVN statistics.
    parser.add_argument(
        '--cmvn_pth', type=str, default="data/mean_std.json", help='')
    # Output path of the merged inference config.
    parser.add_argument(
        '--save_pth', type=str, default="conf/transformer_infer.yaml", help='')
    parser_args = parser.parse_args()

    merge_configs(
        conf_path=parser_args.cfg_pth,
        preprocess_path=parser_args.pre_pth,
        # Forward --dcd_pth; it used to be parsed and then silently ignored.
        decode_path=parser_args.dcd_pth,
        vocab_path=parser_args.vb_pth,
        cmvn_path=parser_args.cmvn_pth,
        save_path=parser_args.save_pth, )
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录