PaddlePaddle / DeepSpeech
Commit 5acb0b52
Authored on Apr 18, 2022 by xiongxinlei

fix the websocket chunk edge bug, test=doc

Parent: 05a8a4b5
Showing 2 changed files with 66 additions and 56 deletions (+66 -56)

paddlespeech/server/engine/asr/online/asr_engine.py   +66 -55
paddlespeech/server/ws/asr_socket.py                   +0 -1
paddlespeech/server/engine/asr/online/asr_engine.py
...
@@ -60,9 +60,9 @@ pretrained_models = {
     },
     "conformer2online_aishell-zh-16k": {
         'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.1.model.tar.gz',
+        'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
         'md5':
-        'b450d5dfaea0ac227c595ce58d18b637',
+        '0ac93d390552336f2a906aec9e33c5fa',
         'cfg_path':
         'model.yaml',
         'ckpt_path':
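The checkpoint URL and its md5 move together: the hash is what lets a client confirm that the 0.2.3 tarball it downloaded is the one this entry points at. A rough illustration of that check (a hedged sketch, not PaddleSpeech's downloader; the local file path is hypothetical):

import hashlib

def md5_of(path, chunk_size=1 << 20):
    # hash the archive in 1 MiB chunks so a large tarball never loads fully into memory
    h = hashlib.md5()
    with open(path, 'rb') as f:
        for block in iter(lambda: f.read(chunk_size), b''):
            h.update(block)
    return h.hexdigest()

expected = '0ac93d390552336f2a906aec9e33c5fa'  # md5 recorded above for the 0.2.3 checkpoint
archive = 'asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz'  # hypothetical local copy
assert md5_of(archive) == expected, "downloaded checkpoint does not match the recorded md5"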
...
@@ -78,12 +78,19 @@ pretrained_models = {
     },
 }

 # ASR server connection process class
 class PaddleASRConnectionHanddler:
     def __init__(self, asr_engine):
+        """Init a Paddle ASR Connection Handler instance
+        Args:
+            asr_engine (ASREngine): the global asr engine
+        """
         super().__init__()
         logger.info(
             "create an paddle asr connection handler to process the websocket connection")
         self.config = asr_engine.config
         self.model_config = asr_engine.executor.config
         self.model = asr_engine.executor.model
...
@@ -105,14 +112,16 @@ class PaddleASRConnectionHanddler:
         # tokens to text
         self.text_feature = self.asr_engine.executor.text_feature

-        # ctc decoding
+        # ctc decoding config
         self.ctc_decode_config = self.asr_engine.executor.config.decode
         self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)

-        # extract fbank
+        # extract feat, new only fbank in conformer model
         self.preprocess_conf = self.model_config.preprocess_config
         self.preprocess_args = {"train": False}
         self.preprocessing = Transformation(self.preprocess_conf)
+
+        # frame window samples length and frame shift samples length
         self.win_length = self.preprocess_conf.process[0]['win_length']
         self.n_shift = self.preprocess_conf.process[0]['n_shift']
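The two values pulled from the preprocess config are in samples: win_length is how many samples one analysis frame covers and n_shift is how far the window advances per frame. A quick illustration of how they relate chunk length to frame count (16 kHz with a 25 ms window and 10 ms shift are common defaults, assumed here rather than read from the config):

sample_rate = 16000
win_length = int(0.025 * sample_rate)   # 400 samples per frame (assumed 25 ms window)
n_shift = int(0.010 * sample_rate)      # 160 samples per hop (assumed 10 ms shift)

num_samples = 8000                      # a hypothetical 0.5 s pcm chunk
num_frames = 1 + (num_samples - win_length) // n_shift
print(num_frames)                       # 48 fbank frames for this chunk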
...
@@ -141,15 +150,17 @@ class PaddleASRConnectionHanddler:
             return 0

         # fbank
         x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
         x_chunk = paddle.to_tensor(
             x_chunk, dtype="float32").unsqueeze(axis=0)

         if self.cached_feat is None:
             self.cached_feat = x_chunk
         else:
             assert (len(x_chunk.shape) == 3)
             assert (len(self.cached_feat.shape) == 3)
-            self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1)
+            self.cached_feat = paddle.concat(
+                [self.cached_feat, x_chunk], axis=1)

         # set the feat device
         if self.device is None:
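The block above is the feature-caching path: every websocket chunk is turned into fbank features and appended to self.cached_feat along the time axis. A standalone sketch of that accumulation pattern (the frame counts and the 80-dim feature size are made up, not read from the model config):

import paddle

cached_feat = None
for frames in (10, 8, 12):                      # hypothetical frame counts of three incoming chunks
    x_chunk = paddle.rand([1, frames, 80])      # (batch, time, feat); 80 dims is an assumption
    if cached_feat is None:
        cached_feat = x_chunk
    else:
        assert len(x_chunk.shape) == 3 and len(cached_feat.shape) == 3
        cached_feat = paddle.concat([cached_feat, x_chunk], axis=1)

print(cached_feat.shape)                        # [1, 30, 80]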
...
@@ -218,15 +229,21 @@ class PaddleASRConnectionHanddler:
             return

         num_frames = self.cached_feat.shape[1]
         logger.info(
             f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames")

         # the cached feat must be larger decoding_window
         if num_frames < decoding_window and not is_finished:
             logger.info(
                 f"frame feat num is less than {decoding_window}, please input more pcm data")
             return None, None

         if num_frames < context:
             logger.info(
                 "flast {num_frames} is less than context {context} frames, and we cannot do model forward")
             return None, None

         logger.info("start to do model forward")
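These two guards decide whether the cached features are ready for a forward pass: wait for more audio while the decoding window is not yet full, and skip decoding when even the final chunk is shorter than the model's receptive context. A self-contained sketch of the same gating (the numeric values are illustrative, not taken from the config):

def can_forward(num_frames: int, decoding_window: int, context: int,
                is_finished: bool) -> bool:
    # not enough frames yet and the client is still streaming: wait for more pcm data
    if num_frames < decoding_window and not is_finished:
        return False
    # even on the last chunk, fewer frames than the model context cannot be decoded
    if num_frames < context:
        return False
    return True

print(can_forward(num_frames=40, decoding_window=67, context=7, is_finished=False))  # False
print(can_forward(num_frames=40, decoding_window=67, context=7, is_finished=True))   # True
print(can_forward(num_frames=3, decoding_window=67, context=7, is_finished=True))    # False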
...
@@ -258,14 +275,11 @@ class PaddleASRConnectionHanddler:
             # update the offset
             self.offset += y.shape[1]

-        logger.info(f"output size: {len(outputs)}")
         ys = paddle.cat(outputs, 1)
         if self.encoder_out is None:
             self.encoder_out = ys
         else:
             self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)

-        # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
-        # masks = masks.unsqueeze(1)

         # get the ctc probs
         ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
...
@@ -274,18 +288,17 @@ class PaddleASRConnectionHanddler:
         self.searcher.search(None, ctc_probs, self.cached_feat.place)

         self.hyps = self.searcher.get_one_best_hyps()
-        assert self.cached_feat.shape[0] == 1
-        assert end >= cached_feature_num
-        self.cached_feat = self.cached_feat[0, end - cached_feature_num:,:].unsqueeze(0)
-        assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"

-        # ys for rescoring
-        # return ys, masks
+        # remove the processed feat
+        if end == num_frames:
+            self.cached_feat = None
+        else:
+            assert self.cached_feat.shape[0] == 1
+            assert end >= cached_feature_num
+            self.cached_feat = self.cached_feat[0, end -
+                                                cached_feature_num:, :].unsqueeze(0)
+            assert len(
+                self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
+
+        logger.info(
+            f"This connection handler encoder out shape: {self.encoder_out.shape}")

     def update_result(self):
         logger.info("update the final result")
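This hunk appears to be the chunk-edge fix named in the commit message: after a window has been decoded, the handler trims self.cached_feat down to the frames still needed as left context, and the added `if end == num_frames` branch covers the edge where the decoded window ends exactly at the last cached frame, so there is nothing left to keep. A minimal sketch of that trimming step, with made-up sizes:

import paddle

cached_feat = paddle.rand([1, 67, 80])   # hypothetical cached fbank features (batch, time, dim)
num_frames = cached_feat.shape[1]
cached_feature_num = 16                  # hypothetical number of frames kept as left context
end = num_frames                         # decoded window ends exactly at the last cached frame

if end == num_frames:
    cached_feat = None                   # the edge case: drop the cache instead of slicing
else:
    assert cached_feat.shape[0] == 1 and end >= cached_feature_num
    cached_feat = cached_feat[0, end - cached_feature_num:, :].unsqueeze(0)
    assert len(cached_feat.shape) == 3

print(cached_feat)                       # None in this edge case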
...
@@ -363,8 +376,6 @@ class PaddleASRConnectionHanddler:
         logger.info(f"best index: {best_index}")
         self.hyps = [hyps[best_index][0]]
         self.update_result()

-        # return hyps[best_index][0]


 class ASRServerExecutor(ASRExecutor):
...
@@ -409,9 +420,9 @@ class ASRServerExecutor(ASRExecutor):
         logger.info(f"Load the pretrained model, tag = {tag}")
         res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
         self.res_path = res_path

-        self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
-        # self.cfg_path = os.path.join(res_path,
-        #                              pretrained_models[tag]['cfg_path'])
+        self.cfg_path = os.path.join(res_path,
+                                     pretrained_models[tag]['cfg_path'])

         self.am_model = os.path.join(res_path,
                                      pretrained_models[tag]['model'])
...
paddlespeech/server/ws/asr_socket.py
...
@@ -96,7 +96,6 @@ async def websocket_endpoint(websocket: WebSocket):
                 asr_results = connection_handler.get_result()
                 resp = {'asr_results': asr_results}
-                print("\n")
                 await websocket.send_json(resp)
     except WebSocketDisconnect:
         pass
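For context, the surrounding endpoint follows the usual FastAPI websocket pattern: receive a chunk, run it through the connection handler, and push the partial result back with send_json; the deleted print("\n") was just console noise. A stripped-down sketch of that loop (the route path and the dummy recognizer are assumptions, not the real server code):

from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()

@app.websocket("/ws/asr")                      # hypothetical route, not necessarily the real one
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            chunk = await websocket.receive_bytes()           # one pcm chunk per message
            asr_results = f"got {len(chunk)} bytes"           # placeholder for the real decoder
            await websocket.send_json({'asr_results': asr_results})
    except WebSocketDisconnect:
        pass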