PaddlePaddle / DeepSpeech
Commit 05a8a4b5

add connection stability, test=doc

Authored Apr 18, 2022 by xiongxinlei
Parent: 68731c61
Showing 3 changed files with 121 additions and 35 deletions.

paddlespeech/server/engine/asr/online/asr_engine.py    +98  -11
paddlespeech/server/engine/asr/online/ctc_search.py    +10  -0
paddlespeech/server/ws/asr_socket.py                    +13  -24
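In short, this commit moves more per-connection state into PaddleASRConnectionHanddler: decode() no longer returns the transcript directly, results are read through the new get_result(), and an attention-rescoring pass runs when the client sends the 'end' signal. A minimal sketch of the call order the websocket path follows after this change; the class below is only a stand-in that mirrors the handler's surface, not the real implementation:

# Stand-in for PaddleASRConnectionHanddler: same call surface, fake internals.
class FakeConnectionHandler:
    def __init__(self):
        self.result_transcripts = ['']

    def extract_feat(self, pcm_bytes):
        pass                              # real handler: buffer pcm, compute fbank features

    def decode(self, is_finished=False):
        # real handler: chunked encoder forward + ctc prefix beam search
        self.result_transcripts = ['partial text' if not is_finished else 'final text']

    def rescoring(self):
        pass                              # real handler: attention rescoring over the beam

    def get_result(self):
        return self.result_transcripts[0] if len(self.result_transcripts) > 0 else ''

    def reset(self):
        self.result_transcripts = ['']


handler = FakeConnectionHandler()
for chunk in (b'\x00' * 3200, b'\x00' * 3200):    # pcm chunks from the websocket
    handler.extract_feat(chunk)
    handler.decode(is_finished=False)
    print(handler.get_result())                   # partial result per chunk

handler.decode(is_finished=True)                  # on the 'end' signal
handler.rescoring()
print(handler.get_result())                       # final, rescored result
handler.reset()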
paddlespeech/server/engine/asr/online/asr_engine.py

@@ -83,8 +83,10 @@ pretrained_models = {
 class PaddleASRConnectionHanddler:
     def __init__(self, asr_engine):
         super().__init__()
+        logger.info("create an paddle asr connection handler to process the websocket connection")
         self.config = asr_engine.config
         self.model_config = asr_engine.executor.config
+        self.model = asr_engine.executor.model
         self.asr_engine = asr_engine
         self.init()
@@ -149,6 +151,10 @@ class PaddleASRConnectionHanddler:
             assert (len(self.cached_feat.shape) == 3)
             self.cached_feat = paddle.concat(
                 [self.cached_feat, x_chunk], axis=1)
+
+        # set the feat device
+        if self.device is None:
+            self.device = self.cached_feat.place
         num_frames = x_chunk.shape[1]
         self.num_frames += num_frames
         self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
@@ -165,16 +171,17 @@ class PaddleASRConnectionHanddler:
         self.subsampling_cache = None
         self.elayers_output_cache = None
         self.conformer_cnn_cache = None
-        self.encoder_outs_ = None
+        self.encoder_out = None
         self.cached_feat = None
         self.remained_wav = None
         self.offset = 0
         self.num_samples = 0
+        self.device = None
+        self.hyps = []
         self.num_frames = 0
         self.chunk_num = 0
         self.global_frame_offset = 0
-        self.result = []
+        self.result_transcripts = ['']

     def decode(self, is_finished=False):
         if "deepspeech2online" in self.model_type:
@@ -187,7 +194,6 @@ class PaddleASRConnectionHanddler:
                 self.advance_decoding(is_finished)
                 self.update_result()
-                return self.result_transcripts[0]
             except Exception as e:
                 logger.exception(e)
         else:
@@ -203,14 +209,24 @@ class PaddleASRConnectionHanddler:
         subsampling = self.model.encoder.embed.subsampling_rate
         context = self.model.encoder.embed.right_context + 1
         stride = subsampling * decoding_chunk_size
+        cached_feature_num = context - subsampling  # processed chunk feature cached for next chunk

         # decoding window for model
         decoding_window = (decoding_chunk_size - 1) * subsampling + context
+        if self.cached_feat is None:
+            logger.info("no audio feat, please input more pcm data")
+            return
+
         num_frames = self.cached_feat.shape[1]
         logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames")

         # the cached feat must be larger decoding_window
         if num_frames < decoding_window and not is_finished:
+            logger.info(f"frame feat num is less than {decoding_window}, please input more pcm data")
+            return None, None
+
+        if num_frames < context:
+            logger.info("flast {num_frames} is less than context {context} frames, and we cannot do model forward")
             return None, None

         logger.info("start to do model forward")
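For intuition about the window bookkeeping above, here is the same arithmetic with made-up conformer values; subsampling_rate = 4, right_context = 6 and decoding_chunk_size = 16 are assumptions for illustration, not values taken from this repo's config:

# Illustrative numbers only; the real values come from the model's embed layer and config.
subsampling = 4                      # assumed encoder subsampling rate
right_context = 6                    # assumed encoder right context
decoding_chunk_size = 16             # assumed chunk size in subsampled frames

context = right_context + 1                                          # 7
stride = subsampling * decoding_chunk_size                            # 64
cached_feature_num = context - subsampling                            # 3 frames kept for the next chunk
decoding_window = (decoding_chunk_size - 1) * subsampling + context   # 67 feature frames per forward

print(context, stride, cached_feature_num, decoding_window)           # 7 64 3 67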
@@ -242,14 +258,18 @@ class PaddleASRConnectionHanddler:
             # update the offset
             self.offset += y.shape[1]

+        logger.info(f"output size: {len(outputs)}")
         ys = paddle.cat(outputs, 1)
-        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
-        masks = masks.unsqueeze(1)
+        if self.encoder_out is None:
+            self.encoder_out = ys
+        else:
+            self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
+        # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
+        # masks = masks.unsqueeze(1)

         # get the ctc probs
         ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
         ctc_probs = ctc_probs.squeeze(0)
+        # self.searcher.search(xs, ctc_probs, xs.place)
         self.searcher.search(None, ctc_probs, self.cached_feat.place)
@@ -260,7 +280,8 @@ class PaddleASRConnectionHanddler:
                 self.cached_feat = None
             else:
                 assert self.cached_feat.shape[0] == 1
-                self.cached_feat = self.cached_feat[0, end:, :].unsqueeze(0)
+                assert end >= cached_feature_num
+                self.cached_feat = self.cached_feat[0, end - cached_feature_num:, :].unsqueeze(0)
                 assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"

         # ys for rescoring
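The slice change above keeps the last cached_feature_num frames before end instead of dropping everything up to end, so the next chunk's encoder forward still sees the left context it needs. A small numeric illustration, with all sizes made up:

import paddle

# Made-up sizes: 80 cached feature frames, 40-dim features, chunk consumed up to frame 64.
cached_feature_num = 3
end = 64
cached_feat = paddle.randn([1, 80, 40])                     # (B=1, T, D)

kept = cached_feat[0, end - cached_feature_num:, :].unsqueeze(0)
print(kept.shape)                                           # [1, 19, 40]: 80 - (64 - 3) frames survive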
@@ -274,9 +295,75 @@ class PaddleASRConnectionHanddler:
         ]
         self.result_tokenids = [hyp for hyp in hyps]

+    def get_result(self):
+        if len(self.result_transcripts) > 0:
+            return self.result_transcripts[0]
+        else:
+            return ''
+
     def rescoring(self):
-        pass
+        logger.info("rescoring the final result")
+        if "attention_rescoring" != self.ctc_decode_config.decoding_method:
+            return
+
+        self.searcher.finalize_search()
+        self.update_result()
+
+        beam_size = self.ctc_decode_config.beam_size
+        hyps = self.searcher.get_hyps()
+        if hyps is None or len(hyps) == 0:
+            return
+
+        # assert len(hyps) == beam_size
+        paddle.save(self.encoder_out, "encoder.out")
+        hyp_list = []
+        for hyp in hyps:
+            hyp_content = hyp[0]
+            # Prevent the hyp is empty
+            if len(hyp_content) == 0:
+                hyp_content = (self.model.ctc.blank_id, )
+            hyp_content = paddle.to_tensor(
+                hyp_content, place=self.device, dtype=paddle.long)
+            hyp_list.append(hyp_content)
+        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
+        hyps_lens = paddle.to_tensor(
+            [len(hyp[0]) for hyp in hyps], place=self.device,
+            dtype=paddle.long)  # (beam_size,)
+        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
+                                  self.model.ignore_id)
+        hyps_lens = hyps_lens + 1  # Add <sos> at begining
+
+        encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
+        encoder_mask = paddle.ones(
+            (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
+        decoder_out, _ = self.model.decoder(
+            encoder_out, encoder_mask, hyps_pad,
+            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
+        # ctc score in ln domain
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
+        decoder_out = decoder_out.numpy()
+
+        # Only use decoder score for rescoring
+        best_score = -float('inf')
+        best_index = 0
+        # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
+        for i, hyp in enumerate(hyps):
+            score = 0.0
+            for j, w in enumerate(hyp[0]):
+                score += decoder_out[i][j][w]
+            # last decoder output token is `eos`, for laste decoder input token.
+            score += decoder_out[i][len(hyp[0])][self.model.eos]
+            # add ctc score (which in ln domain)
+            score += hyp[1] * self.ctc_decode_config.ctc_weight
+            if score > best_score:
+                best_score = score
+                best_index = i
+
+        # update the one best result
+        logger.info(f"best index: {best_index}")
+        self.hyps = [hyps[best_index][0]]
+        self.update_result()
+        # return hyps[best_index][0]
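The scoring loop in the new rescoring() combines the attention decoder's log-probability of each hypothesis (including the final <eos> step) with the CTC score, weighted by ctc_weight. A numpy sketch of that combination with a fake decoder output; the eos id and all numbers are illustrative only:

import numpy as np

ctc_weight = 0.3
eos = 4                                          # assumed <eos> token id
hyps = [([1, 2], -1.2), ([3], -0.8)]             # (token ids, ctc score in ln domain) per beam
# (beam_size, max_hyps_len + 1, vocab_size), already log-softmaxed
decoder_out = np.log(np.full((2, 3, 5), 0.2))

best_score, best_index = -float('inf'), 0
for i, (tokens, ctc_score) in enumerate(hyps):
    score = sum(decoder_out[i][j][w] for j, w in enumerate(tokens))
    score += decoder_out[i][len(tokens)][eos]    # last decoder step should emit <eos>
    score += ctc_score * ctc_weight              # blend in the ctc score (ln domain)
    if score > best_score:
        best_score, best_index = score, i

print(best_index, best_score)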
paddlespeech/server/engine/asr/online/ctc_search.py

@@ -110,6 +110,11 @@ class CTCPrefixBeamSearch:
         return [self.hyps[0][0]]

     def get_hyps(self):
+        """Return the search hyps
+
+        Returns:
+            list: return the search hyps
+        """
         return self.hyps

     def reset(self):
@@ -117,3 +122,8 @@ class CTCPrefixBeamSearch:
         """
         self.cur_hyps = None
         self.hyps = None
+
+    def finalize_search(self):
+        """do nothing in ctc_prefix_beam_search
+        """
+        pass
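These two additions give the searcher the small contract that the handler's rescoring() relies on: get_hyps() returns the full beam as (token_ids, score) pairs and finalize_search() is a no-op for plain CTC prefix beam search. A stand-in illustrating that contract, not the real searcher:

class FakeSearcher:
    """Stand-in with the interface rescoring() expects from CTCPrefixBeamSearch."""

    def __init__(self):
        # beam of (token ids, ln-domain ctc score), best first
        self.hyps = [((1, 2, 3), -0.5), ((1, 2), -0.9)]

    def get_one_best_hyps(self):
        return [self.hyps[0][0]]

    def get_hyps(self):
        return self.hyps

    def finalize_search(self):
        pass                      # nothing to flush for ctc prefix beam search

    def reset(self):
        self.hyps = None


searcher = FakeSearcher()
searcher.finalize_search()
print(searcher.get_hyps())        # [((1, 2, 3), -0.5), ((1, 2), -0.9)]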
paddlespeech/server/ws/asr_socket.py

@@ -13,16 +13,15 @@
 # limitations under the License.
 import json

-import numpy as np
 from fastapi import APIRouter
 from fastapi import WebSocket
 from fastapi import WebSocketDisconnect
 from starlette.websockets import WebSocketState as WebSocketState

+from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
 from paddlespeech.server.engine.engine_pool import get_engine_pool
 from paddlespeech.server.utils.buffer import ChunkBuffer
 from paddlespeech.server.utils.vad import VADAudio
-from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler

 router = APIRouter()
@@ -73,13 +72,17 @@ async def websocket_endpoint(websocket: WebSocket):
                     connection_handler = PaddleASRConnectionHanddler(asr_engine)
                     await websocket.send_json(resp)
                 elif message['signal'] == 'end':
+                    engine_pool = get_engine_pool()
+                    asr_engine = engine_pool['asr']
                     # reset single engine for an new connection
-                    asr_results = connection_handler.decode(is_finished=True)
+                    connection_handler.decode(is_finished=True)
+                    connection_handler.rescoring()
+                    asr_results = connection_handler.get_result()
                     connection_handler.reset()
+                    asr_engine.reset()
                     resp = {
                         "status": "ok",
                         "signal": "finished",
                         'asr_results': asr_results
                     }
                     await websocket.send_json(resp)
                     break
                 else:
@@ -87,25 +90,11 @@ async def websocket_endpoint(websocket: WebSocket):
                     await websocket.send_json(resp)

             elif "bytes" in message:
                 message = message["bytes"]
-                engine_pool = get_engine_pool()
-                asr_engine = engine_pool['asr']
-                asr_results = ""
-                connection_handler.extract_feat(message)
-                asr_results = connection_handler.decode(is_finished=False)
-                # connection_handler.
-                # frames = chunk_buffer.frame_generator(message)
-                # for frame in frames:
-                #     # get the pcm data from the bytes
-                #     samples = np.frombuffer(frame.bytes, dtype=np.int16)
-                #     sample_rate = asr_engine.config.sample_rate
-                #     x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
-                #                                                   sample_rate)
-                #     asr_engine.run(x_chunk, x_chunk_lens)
-                #     asr_results = asr_engine.postprocess()
-                # # connection accept the sample data frame by frame

+                connection_handler.extract_feat(message)
+                connection_handler.decode(is_finished=False)
+                asr_results = connection_handler.get_result()
+                # asr_results = asr_engine.postprocess()
                 resp = {'asr_results': asr_results}
                 print("\n")
                 await websocket.send_json(resp)