Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
af484fc9
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
af484fc9
编写于
4月 14, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
convert websockert results to str from bytest, test=doc
上级
23a65341
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
48 addition
and
15 deletion
+48
-15
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+16
-7
paddlespeech/server/tests/asr/online/websocket_client.py
paddlespeech/server/tests/asr/online/websocket_client.py
+32
-8
未找到文件。
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
af484fc9
...
...
@@ -35,9 +35,9 @@ __all__ = ['ASREngine']
pretrained_models
=
{
"deepspeech2online_aishell-zh-16k"
:
{
'url'
:
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.
1.1
.model.tar.gz'
,
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.
2.0
.model.tar.gz'
,
'md5'
:
'
d5e076217cf60486519f72c217d21b9b
'
,
'
23e16c69730a1cb5d735c98c83c21e16
'
,
'cfg_path'
:
'model.yaml'
,
'ckpt_path'
:
...
...
@@ -75,6 +75,7 @@ class ASRServerExecutor(ASRExecutor):
if
cfg_path
is
None
or
am_model
is
None
or
am_params
is
None
:
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
logger
.
info
(
f
"Load the pretrained model, tag =
{
tag
}
"
)
res_path
=
self
.
_get_pretrained_path
(
tag
)
# wenetspeech_zh
self
.
res_path
=
res_path
self
.
cfg_path
=
os
.
path
.
join
(
res_path
,
...
...
@@ -85,9 +86,6 @@ class ASRServerExecutor(ASRExecutor):
self
.
am_params
=
os
.
path
.
join
(
res_path
,
pretrained_models
[
tag
][
'params'
])
logger
.
info
(
res_path
)
logger
.
info
(
self
.
cfg_path
)
logger
.
info
(
self
.
am_model
)
logger
.
info
(
self
.
am_params
)
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
am_model
=
os
.
path
.
abspath
(
am_model
)
...
...
@@ -95,6 +93,10 @@ class ASRServerExecutor(ASRExecutor):
self
.
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
logger
.
info
(
self
.
cfg_path
)
logger
.
info
(
self
.
am_model
)
logger
.
info
(
self
.
am_params
)
#Init body.
self
.
config
=
CfgNode
(
new_allowed
=
True
)
self
.
config
.
merge_from_file
(
self
.
cfg_path
)
...
...
@@ -112,15 +114,20 @@ class ASRServerExecutor(ASRExecutor):
lm_url
=
pretrained_models
[
tag
][
'lm_url'
]
lm_md5
=
pretrained_models
[
tag
][
'lm_md5'
]
logger
.
info
(
f
"Start to load language model
{
lm_url
}
"
)
self
.
download_lm
(
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
or
"wenetspeech"
in
model_type
:
raise
Exception
(
"wrong type"
)
# 开发 conformer 的流式模型
logger
.
info
(
"start to create the stream conformer asr engine"
)
# 复用cli里面的代码
else
:
raise
Exception
(
"wrong type"
)
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor_conf
=
am_predictor_conf
self
.
am_predictor
=
init_predictor
(
model_file
=
self
.
am_model
,
...
...
@@ -128,6 +135,7 @@ class ASRServerExecutor(ASRExecutor):
predictor_conf
=
self
.
am_predictor_conf
)
# decoder
logger
.
info
(
"ASR engine start to create the ctc decoder instance"
)
self
.
decoder
=
CTCDecoder
(
odim
=
self
.
config
.
output_dim
,
# <blank> is in vocab
enc_n_units
=
self
.
config
.
rnn_layer_size
*
2
,
...
...
@@ -138,6 +146,7 @@ class ASRServerExecutor(ASRExecutor):
grad_norm_type
=
self
.
config
.
get
(
'ctc_grad_norm_type'
,
None
))
# init decoder
logger
.
info
(
"ASR engine start to init the ctc decoder"
)
cfg
=
self
.
config
.
decode
decode_batch_size
=
1
# for online
self
.
decoder
.
init_decoder
(
...
...
@@ -215,7 +224,6 @@ class ASRServerExecutor(ASRExecutor):
self
.
decoder
.
next
(
output_chunk_probs
,
output_chunk_lens
)
trans_best
,
trans_beam
=
self
.
decoder
.
decode
()
return
trans_best
[
0
]
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
...
...
@@ -273,6 +281,7 @@ class ASREngine(BaseEngine):
def
__init__
(
self
):
super
(
ASREngine
,
self
).
__init__
()
logger
.
info
(
"create the online asr engine instache"
)
def
init
(
self
,
config
:
dict
)
->
bool
:
"""init engine resource
...
...
paddlespeech/server/tests/asr/online/websocket_client.py
浏览文件 @
af484fc9
...
...
@@ -15,8 +15,10 @@
# -*- coding: UTF-8 -*-
import
argparse
import
asyncio
import
codecs
import
json
import
logging
import
os
import
numpy
as
np
import
soundfile
...
...
@@ -54,12 +56,11 @@ class ASRAudioHandler:
async
def
run
(
self
,
wavfile_path
:
str
):
logging
.
info
(
"send a message to the server"
)
# 读取音频
# self.read_wave()
#
发送 websocket 的 handshake 协议头
#
send websocket handshake protocal
async
with
websockets
.
connect
(
self
.
url
)
as
ws
:
# server
端已经接收到 handshake 协议头
#
发送开始指令
# server
has already received handshake protocal
#
client start to send the command
audio_info
=
json
.
dumps
(
{
"name"
:
"test.wav"
,
...
...
@@ -77,8 +78,9 @@ class ASRAudioHandler:
for
chunk_data
in
self
.
read_wave
(
wavfile_path
):
await
ws
.
send
(
chunk_data
.
tobytes
())
msg
=
await
ws
.
recv
()
msg
=
json
.
loads
(
msg
)
logging
.
info
(
"receive msg={}"
.
format
(
msg
))
result
=
msg
# finished
audio_info
=
json
.
dumps
(
{
...
...
@@ -91,16 +93,36 @@ class ASRAudioHandler:
separators
=
(
','
,
': '
))
await
ws
.
send
(
audio_info
)
msg
=
await
ws
.
recv
()
# decode the bytes to str
msg
=
json
.
loads
(
msg
)
logging
.
info
(
"receive msg={}"
.
format
(
msg
))
return
result
def
main
(
args
):
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
info
(
"asr websocket client start"
)
handler
=
ASRAudioHandler
(
"127.0.0.1"
,
809
1
)
handler
=
ASRAudioHandler
(
"127.0.0.1"
,
809
0
)
loop
=
asyncio
.
get_event_loop
()
loop
.
run_until_complete
(
handler
.
run
(
args
.
wavfile
))
logging
.
info
(
"asr websocket client finished"
)
# support to process single audio file
if
args
.
wavfile
and
os
.
path
.
exists
(
args
.
wavfile
):
logging
.
info
(
f
"start to process the wavscp:
{
args
.
wavfile
}
"
)
result
=
loop
.
run_until_complete
(
handler
.
run
(
args
.
wavfile
))
result
=
result
[
"asr_results"
]
logging
.
info
(
f
"asr websocket client finished :
{
result
}
"
)
# support to process batch audios from wav.scp
if
args
.
wavscp
and
os
.
path
.
exists
(
args
.
wavscp
):
logging
.
info
(
f
"start to process the wavscp:
{
args
.
wavscp
}
"
)
with
codecs
.
open
(
args
.
wavscp
,
'r'
,
encoding
=
'utf-8'
)
as
f
,
\
codecs
.
open
(
"result.txt"
,
'w'
,
encoding
=
'utf-8'
)
as
w
:
for
line
in
f
:
utt_name
,
utt_path
=
line
.
strip
().
split
()
result
=
loop
.
run_until_complete
(
handler
.
run
(
utt_path
))
result
=
result
[
"asr_results"
]
w
.
write
(
f
"
{
utt_name
}
{
result
}
\n
"
)
if
__name__
==
"__main__"
:
...
...
@@ -110,6 +132,8 @@ if __name__ == "__main__":
action
=
"store"
,
help
=
"wav file path "
,
default
=
"./16_audio.wav"
)
parser
.
add_argument
(
"--wavscp"
,
type
=
str
,
default
=
None
,
help
=
"The batch audios dict text"
)
args
=
parser
.
parse_args
()
main
(
args
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录