Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
b9e3e493
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b9e3e493
编写于
6月 15, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor stream asr and fix ds2 stream bug
上级
bca014fd
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
96 addition
and
71 deletion
+96
-71
demos/streaming_asr_server/test.sh
demos/streaming_asr_server/test.sh
+1
-1
paddlespeech/server/engine/asr/online/python/asr_engine.py
paddlespeech/server/engine/asr/online/python/asr_engine.py
+91
-69
paddlespeech/server/engine/engine_factory.py
paddlespeech/server/engine/engine_factory.py
+4
-1
未找到文件。
demos/streaming_asr_server/test.sh
浏览文件 @
b9e3e493
...
@@ -4,7 +4,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
...
@@ -4,7 +4,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
# read the wav and pass it to only streaming asr service
# read the wav and pass it to only streaming asr service
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
paddlespeech_client asr_online
--server_ip
127.0.0.1
--port
8
2
90
--input
./zh.wav
paddlespeech_client asr_online
--server_ip
127.0.0.1
--port
8
0
90
--input
./zh.wav
# read the wav and call streaming and punc service
# read the wav and call streaming and punc service
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
...
...
paddlespeech/server/engine/asr/online/asr_engine.py
→
paddlespeech/server/engine/asr/online/
python/
asr_engine.py
浏览文件 @
b9e3e493
...
@@ -121,13 +121,14 @@ class PaddleASRConnectionHanddler:
...
@@ -121,13 +121,14 @@ class PaddleASRConnectionHanddler:
raise
ValueError
(
f
"Not supported:
{
self
.
model_type
}
"
)
raise
ValueError
(
f
"Not supported:
{
self
.
model_type
}
"
)
def
model_reset
(
self
):
def
model_reset
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
return
# cache for audio and feat
# cache for audio and feat
self
.
remained_wav
=
None
self
.
remained_wav
=
None
self
.
cached_feat
=
None
self
.
cached_feat
=
None
if
"deepspeech2"
in
self
.
model_type
:
return
## conformer
## conformer
# cache for conformer online
# cache for conformer online
self
.
subsampling_cache
=
None
self
.
subsampling_cache
=
None
...
@@ -697,6 +698,67 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -697,6 +698,67 @@ class ASRServerExecutor(ASRExecutor):
self
.
task_resource
=
CommonTaskResource
(
self
.
task_resource
=
CommonTaskResource
(
task
=
'asr'
,
model_format
=
'dynamic'
,
inference_mode
=
'online'
)
task
=
'asr'
,
model_format
=
'dynamic'
,
inference_mode
=
'online'
)
def
update_config
(
self
)
->
None
:
if
"deepspeech2"
in
self
.
model_type
:
with
UpdateConfig
(
self
.
config
):
# download lm
self
.
config
.
decode
.
lang_model_path
=
os
.
path
.
join
(
MODEL_HOME
,
'language_model'
,
self
.
config
.
decode
.
lang_model_path
)
lm_url
=
self
.
task_resource
.
res_dict
[
'lm_url'
]
lm_md5
=
self
.
task_resource
.
res_dict
[
'lm_md5'
]
logger
.
info
(
f
"Start to load language model
{
lm_url
}
"
)
self
.
download_lm
(
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
with
UpdateConfig
(
self
.
config
):
logger
.
info
(
"start to create the stream conformer asr engine"
)
# update the decoding method
if
self
.
decode_method
:
self
.
config
.
decode
.
decoding_method
=
self
.
decode_method
# update num_decoding_left_chunks
if
self
.
num_decoding_left_chunks
:
assert
self
.
num_decoding_left_chunks
==
-
1
or
self
.
num_decoding_left_chunks
>=
0
,
f
"num_decoding_left_chunks should be -1 or >=0"
self
.
config
.
decode
.
num_decoding_left_chunks
=
self
.
num_decoding_left_chunks
# we only support ctc_prefix_beam_search and attention_rescoring dedoding method
# Generally we set the decoding_method to attention_rescoring
if
self
.
config
.
decode
.
decoding_method
not
in
[
"ctc_prefix_beam_search"
,
"attention_rescoring"
]:
logger
.
info
(
"we set the decoding_method to attention_rescoring"
)
self
.
config
.
decode
.
decoding_method
=
"attention_rescoring"
assert
self
.
config
.
decode
.
decoding_method
in
[
"ctc_prefix_beam_search"
,
"attention_rescoring"
],
f
"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is
{
self
.
config
.
decode
.
decoding_method
}
"
else
:
raise
Exception
(
f
"not support:
{
self
.
model_type
}
"
)
def
init_model
(
self
)
->
None
:
if
"deepspeech2"
in
self
.
model_type
:
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor
=
init_predictor
(
model_file
=
self
.
am_model
,
params_file
=
self
.
am_params
,
predictor_conf
=
self
.
am_predictor_conf
)
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
# load model
# model_type: {model_name}_{dataset}
model_name
=
self
.
model_type
[:
self
.
model_type
.
rindex
(
'_'
)]
logger
.
info
(
f
"model name:
{
model_name
}
"
)
model_class
=
self
.
task_resource
.
get_model_class
(
model_name
)
model
=
model_class
.
from_config
(
self
.
config
)
self
.
model
=
model
self
.
model
.
set_state_dict
(
paddle
.
load
(
self
.
am_model
))
self
.
model
.
eval
()
else
:
raise
Exception
(
f
"not support:
{
self
.
model_type
}
"
)
def
_init_from_path
(
self
,
def
_init_from_path
(
self
,
model_type
:
str
=
None
,
model_type
:
str
=
None
,
am_model
:
Optional
[
os
.
PathLike
]
=
None
,
am_model
:
Optional
[
os
.
PathLike
]
=
None
,
...
@@ -718,8 +780,13 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -718,8 +780,13 @@ class ASRServerExecutor(ASRExecutor):
self
.
model_type
=
model_type
self
.
model_type
=
model_type
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
self
.
decode_method
=
decode_method
self
.
num_decoding_left_chunks
=
num_decoding_left_chunks
# conf for paddleinference predictor or onnx
self
.
am_predictor_conf
=
am_predictor_conf
logger
.
info
(
f
"model_type:
{
self
.
model_type
}
"
)
logger
.
info
(
f
"model_type:
{
self
.
model_type
}
"
)
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
self
.
task_resource
.
set_task_model
(
model_tag
=
tag
)
self
.
task_resource
.
set_task_model
(
model_tag
=
tag
)
...
@@ -763,62 +830,10 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -763,62 +830,10 @@ class ASRServerExecutor(ASRExecutor):
vocab
=
self
.
config
.
vocab_filepath
,
vocab
=
self
.
config
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
spm_model_prefix
)
spm_model_prefix
=
self
.
config
.
spm_model_prefix
)
if
"deepspeech2"
in
model_type
:
self
.
update_config
()
with
UpdateConfig
(
self
.
config
):
# download lm
# AM predictor
self
.
config
.
decode
.
lang_model_path
=
os
.
path
.
join
(
self
.
init_model
()
MODEL_HOME
,
'language_model'
,
self
.
config
.
decode
.
lang_model_path
)
lm_url
=
self
.
task_resource
.
res_dict
[
'lm_url'
]
lm_md5
=
self
.
task_resource
.
res_dict
[
'lm_md5'
]
logger
.
info
(
f
"Start to load language model
{
lm_url
}
"
)
self
.
download_lm
(
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor_conf
=
am_predictor_conf
self
.
am_predictor
=
init_predictor
(
model_file
=
self
.
am_model
,
params_file
=
self
.
am_params
,
predictor_conf
=
self
.
am_predictor_conf
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
with
UpdateConfig
(
self
.
config
):
logger
.
info
(
"start to create the stream conformer asr engine"
)
# update the decoding method
if
decode_method
:
self
.
config
.
decode
.
decoding_method
=
decode_method
# update num_decoding_left_chunks
if
num_decoding_left_chunks
:
assert
num_decoding_left_chunks
==
-
1
or
num_decoding_left_chunks
>=
0
,
f
"num_decoding_left_chunks should be -1 or >=0"
self
.
config
.
decode
.
num_decoding_left_chunks
=
num_decoding_left_chunks
# we only support ctc_prefix_beam_search and attention_rescoring dedoding method
# Generally we set the decoding_method to attention_rescoring
if
self
.
config
.
decode
.
decoding_method
not
in
[
"ctc_prefix_beam_search"
,
"attention_rescoring"
]:
logger
.
info
(
"we set the decoding_method to attention_rescoring"
)
self
.
config
.
decode
.
decoding_method
=
"attention_rescoring"
assert
self
.
config
.
decode
.
decoding_method
in
[
"ctc_prefix_beam_search"
,
"attention_rescoring"
],
f
"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is
{
self
.
config
.
decode
.
decoding_method
}
"
# load model
model_name
=
model_type
[:
model_type
.
rindex
(
'_'
)]
# model_type: {model_name}_{dataset}
logger
.
info
(
f
"model name:
{
model_name
}
"
)
model_class
=
self
.
task_resource
.
get_model_class
(
model_name
)
model
=
model_class
.
from_config
(
self
.
config
)
self
.
model
=
model
self
.
model
.
set_state_dict
(
paddle
.
load
(
self
.
am_model
))
self
.
model
.
eval
()
else
:
raise
Exception
(
f
"not support:
{
model_type
}
"
)
logger
.
info
(
f
"create the
{
model_type
}
model success"
)
logger
.
info
(
f
"create the
{
model_type
}
model success"
)
return
True
return
True
...
@@ -835,6 +850,22 @@ class ASREngine(BaseEngine):
...
@@ -835,6 +850,22 @@ class ASREngine(BaseEngine):
super
(
ASREngine
,
self
).
__init__
()
super
(
ASREngine
,
self
).
__init__
()
logger
.
info
(
"create the online asr engine resource instance"
)
logger
.
info
(
"create the online asr engine resource instance"
)
def
init_model
(
self
)
->
bool
:
if
not
self
.
executor
.
_init_from_path
(
model_type
=
self
.
config
.
model_type
,
am_model
=
self
.
config
.
am_model
,
am_params
=
self
.
config
.
am_params
,
lang
=
self
.
config
.
lang
,
sample_rate
=
self
.
config
.
sample_rate
,
cfg_path
=
self
.
config
.
cfg_path
,
decode_method
=
self
.
config
.
decode_method
,
num_decoding_left_chunks
=
self
.
config
.
num_decoding_left_chunks
,
am_predictor_conf
=
self
.
config
.
am_predictor_conf
):
return
False
return
True
def
init
(
self
,
config
:
dict
)
->
bool
:
def
init
(
self
,
config
:
dict
)
->
bool
:
"""init engine resource
"""init engine resource
...
@@ -860,16 +891,7 @@ class ASREngine(BaseEngine):
...
@@ -860,16 +891,7 @@ class ASREngine(BaseEngine):
logger
.
info
(
f
"paddlespeech_server set the device:
{
self
.
device
}
"
)
logger
.
info
(
f
"paddlespeech_server set the device:
{
self
.
device
}
"
)
if
not
self
.
executor
.
_init_from_path
(
if
not
self
.
init_model
():
model_type
=
self
.
config
.
model_type
,
am_model
=
self
.
config
.
am_model
,
am_params
=
self
.
config
.
am_params
,
lang
=
self
.
config
.
lang
,
sample_rate
=
self
.
config
.
sample_rate
,
cfg_path
=
self
.
config
.
cfg_path
,
decode_method
=
self
.
config
.
decode_method
,
num_decoding_left_chunks
=
self
.
config
.
num_decoding_left_chunks
,
am_predictor_conf
=
self
.
config
.
am_predictor_conf
):
logger
.
error
(
logger
.
error
(
"Init the ASR server occurs error, please check the server configuration yaml"
"Init the ASR server occurs error, please check the server configuration yaml"
)
)
...
...
paddlespeech/server/engine/engine_factory.py
浏览文件 @
b9e3e493
...
@@ -26,7 +26,10 @@ class EngineFactory(object):
...
@@ -26,7 +26,10 @@ class EngineFactory(object):
from
paddlespeech.server.engine.asr.python.asr_engine
import
ASREngine
from
paddlespeech.server.engine.asr.python.asr_engine
import
ASREngine
return
ASREngine
()
return
ASREngine
()
elif
engine_name
==
'asr'
and
engine_type
==
'online'
:
elif
engine_name
==
'asr'
and
engine_type
==
'online'
:
from
paddlespeech.server.engine.asr.online.asr_engine
import
ASREngine
from
paddlespeech.server.engine.asr.online.python.asr_engine
import
ASREngine
return
ASREngine
()
elif
engine_name
==
'asr'
and
engine_type
==
'online-onnx'
:
from
paddlespeech.server.engine.asr.online.onnx.asr_engine
import
ASREngine
return
ASREngine
()
return
ASREngine
()
elif
engine_name
==
'tts'
and
engine_type
==
'inference'
:
elif
engine_name
==
'tts'
and
engine_type
==
'inference'
:
from
paddlespeech.server.engine.tts.paddleinference.tts_engine
import
TTSEngine
from
paddlespeech.server.engine.tts.paddleinference.tts_engine
import
TTSEngine
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录