Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
36c9eaa4
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
36c9eaa4
编写于
1月 04, 2022
作者:
老熊宝宝
提交者:
GitHub
1月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cache the TextFeaturizer instance for infer speed improvement. (#1260)
上级
50752f8b
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
10 addition
and
15 deletion
+10
-15
paddlespeech/cli/asr/infer.py
paddlespeech/cli/asr/infer.py
+10
-15
未找到文件。
paddlespeech/cli/asr/infer.py
浏览文件 @
36c9eaa4
...
@@ -174,12 +174,12 @@ class ASRExecutor(BaseExecutor):
...
@@ -174,12 +174,12 @@ class ASRExecutor(BaseExecutor):
self
.
config
.
collator
.
mean_std_filepath
=
os
.
path
.
join
(
self
.
config
.
collator
.
mean_std_filepath
=
os
.
path
.
join
(
res_path
,
self
.
config
.
collator
.
cmvn_path
)
res_path
,
self
.
config
.
collator
.
cmvn_path
)
self
.
collate_fn_test
=
SpeechCollator
.
from_config
(
self
.
config
)
self
.
collate_fn_test
=
SpeechCollator
.
from_config
(
self
.
config
)
text_feature
=
TextFeaturizer
(
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
collator
.
unit_type
,
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab
=
self
.
config
.
collator
.
vocab_filepath
,
vocab
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
self
.
config
.
model
.
input_dim
=
self
.
collate_fn_test
.
feature_size
self
.
config
.
model
.
input_dim
=
self
.
collate_fn_test
.
feature_size
self
.
config
.
model
.
output_dim
=
text_feature
.
vocab_size
self
.
config
.
model
.
output_dim
=
self
.
text_feature
.
vocab_size
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
or
"wenetspeech"
in
model_type
:
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
or
"wenetspeech"
in
model_type
:
self
.
config
.
collator
.
vocab_filepath
=
os
.
path
.
join
(
self
.
config
.
collator
.
vocab_filepath
=
os
.
path
.
join
(
res_path
,
self
.
config
.
collator
.
vocab_filepath
)
res_path
,
self
.
config
.
collator
.
vocab_filepath
)
...
@@ -187,12 +187,12 @@ class ASRExecutor(BaseExecutor):
...
@@ -187,12 +187,12 @@ class ASRExecutor(BaseExecutor):
res_path
,
self
.
config
.
collator
.
augmentation_config
)
res_path
,
self
.
config
.
collator
.
augmentation_config
)
self
.
config
.
collator
.
spm_model_prefix
=
os
.
path
.
join
(
self
.
config
.
collator
.
spm_model_prefix
=
os
.
path
.
join
(
res_path
,
self
.
config
.
collator
.
spm_model_prefix
)
res_path
,
self
.
config
.
collator
.
spm_model_prefix
)
text_feature
=
TextFeaturizer
(
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
collator
.
unit_type
,
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab
=
self
.
config
.
collator
.
vocab_filepath
,
vocab
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
self
.
config
.
model
.
input_dim
=
self
.
config
.
collator
.
feat_dim
self
.
config
.
model
.
input_dim
=
self
.
config
.
collator
.
feat_dim
self
.
config
.
model
.
output_dim
=
text_feature
.
vocab_size
self
.
config
.
model
.
output_dim
=
self
.
text_feature
.
vocab_size
else
:
else
:
raise
Exception
(
"wrong type"
)
raise
Exception
(
"wrong type"
)
...
@@ -211,6 +211,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -211,6 +211,7 @@ class ASRExecutor(BaseExecutor):
model_dict
=
paddle
.
load
(
self
.
ckpt_path
)
model_dict
=
paddle
.
load
(
self
.
ckpt_path
)
self
.
model
.
set_state_dict
(
model_dict
)
self
.
model
.
set_state_dict
(
model_dict
)
def
preprocess
(
self
,
model_type
:
str
,
input
:
Union
[
str
,
os
.
PathLike
]):
def
preprocess
(
self
,
model_type
:
str
,
input
:
Union
[
str
,
os
.
PathLike
]):
"""
"""
Input preprocess and return paddle.Tensor stored in self.input.
Input preprocess and return paddle.Tensor stored in self.input.
...
@@ -228,7 +229,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -228,7 +229,7 @@ class ASRExecutor(BaseExecutor):
audio
=
paddle
.
to_tensor
(
audio
,
dtype
=
'float32'
)
audio
=
paddle
.
to_tensor
(
audio
,
dtype
=
'float32'
)
audio_len
=
paddle
.
to_tensor
(
audio_len
)
audio_len
=
paddle
.
to_tensor
(
audio_len
)
audio
=
paddle
.
unsqueeze
(
audio
,
axis
=
0
)
audio
=
paddle
.
unsqueeze
(
audio
,
axis
=
0
)
vocab_list
=
collate_fn_test
.
vocab_list
#
vocab_list = collate_fn_test.vocab_list
self
.
_inputs
[
"audio"
]
=
audio
self
.
_inputs
[
"audio"
]
=
audio
self
.
_inputs
[
"audio_len"
]
=
audio_len
self
.
_inputs
[
"audio_len"
]
=
audio_len
logger
.
info
(
f
"audio feat shape:
{
audio
.
shape
}
"
)
logger
.
info
(
f
"audio feat shape:
{
audio
.
shape
}
"
)
...
@@ -274,10 +275,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -274,10 +275,7 @@ class ASRExecutor(BaseExecutor):
audio_len
=
paddle
.
to_tensor
(
audio
.
shape
[
0
])
audio_len
=
paddle
.
to_tensor
(
audio
.
shape
[
0
])
audio
=
paddle
.
to_tensor
(
audio
,
dtype
=
'float32'
).
unsqueeze
(
axis
=
0
)
audio
=
paddle
.
to_tensor
(
audio
,
dtype
=
'float32'
).
unsqueeze
(
axis
=
0
)
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
self
.
_inputs
[
"audio"
]
=
audio
self
.
_inputs
[
"audio"
]
=
audio
self
.
_inputs
[
"audio_len"
]
=
audio_len
self
.
_inputs
[
"audio_len"
]
=
audio_len
logger
.
info
(
f
"audio feat shape:
{
audio
.
shape
}
"
)
logger
.
info
(
f
"audio feat shape:
{
audio
.
shape
}
"
)
...
@@ -290,10 +288,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -290,10 +288,7 @@ class ASRExecutor(BaseExecutor):
"""
"""
Model inference and result stored in self.output.
Model inference and result stored in self.output.
"""
"""
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
cfg
=
self
.
config
.
decoding
cfg
=
self
.
config
.
decoding
audio
=
self
.
_inputs
[
"audio"
]
audio
=
self
.
_inputs
[
"audio"
]
audio_len
=
self
.
_inputs
[
"audio_len"
]
audio_len
=
self
.
_inputs
[
"audio_len"
]
...
@@ -301,7 +296,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -301,7 +296,7 @@ class ASRExecutor(BaseExecutor):
result_transcripts
=
self
.
model
.
decode
(
result_transcripts
=
self
.
model
.
decode
(
audio
,
audio
,
audio_len
,
audio_len
,
text_feature
.
vocab_list
,
self
.
text_feature
.
vocab_list
,
decoding_method
=
cfg
.
decoding_method
,
decoding_method
=
cfg
.
decoding_method
,
lang_model_path
=
cfg
.
lang_model_path
,
lang_model_path
=
cfg
.
lang_model_path
,
beam_alpha
=
cfg
.
alpha
,
beam_alpha
=
cfg
.
alpha
,
...
@@ -316,7 +311,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -316,7 +311,7 @@ class ASRExecutor(BaseExecutor):
result_transcripts
=
self
.
model
.
decode
(
result_transcripts
=
self
.
model
.
decode
(
audio
,
audio
,
audio_len
,
audio_len
,
text_feature
=
text_feature
,
text_feature
=
self
.
text_feature
,
decoding_method
=
cfg
.
decoding_method
,
decoding_method
=
cfg
.
decoding_method
,
beam_size
=
cfg
.
beam_size
,
beam_size
=
cfg
.
beam_size
,
ctc_weight
=
cfg
.
ctc_weight
,
ctc_weight
=
cfg
.
ctc_weight
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录