Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
05a8a4b5
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
05a8a4b5
编写于
4月 18, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add connection stability, test=doc
上级
68731c61
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
121 addition
and
35 deletion
+121
-35
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+98
-11
paddlespeech/server/engine/asr/online/ctc_search.py
paddlespeech/server/engine/asr/online/ctc_search.py
+10
-0
paddlespeech/server/ws/asr_socket.py
paddlespeech/server/ws/asr_socket.py
+13
-24
未找到文件。
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
05a8a4b5
...
...
@@ -83,8 +83,10 @@ pretrained_models = {
class
PaddleASRConnectionHanddler
:
def
__init__
(
self
,
asr_engine
):
super
().
__init__
()
logger
.
info
(
"create an paddle asr connection handler to process the websocket connection"
)
self
.
config
=
asr_engine
.
config
self
.
model_config
=
asr_engine
.
executor
.
config
self
.
model
=
asr_engine
.
executor
.
model
self
.
asr_engine
=
asr_engine
self
.
init
()
...
...
@@ -149,6 +151,10 @@ class PaddleASRConnectionHanddler:
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
self
.
cached_feat
=
paddle
.
concat
([
self
.
cached_feat
,
x_chunk
],
axis
=
1
)
# set the feat device
if
self
.
device
is
None
:
self
.
device
=
self
.
cached_feat
.
place
num_frames
=
x_chunk
.
shape
[
1
]
self
.
num_frames
+=
num_frames
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
...
...
@@ -165,16 +171,17 @@ class PaddleASRConnectionHanddler:
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
encoder_out
s_
=
None
self
.
encoder_out
=
None
self
.
cached_feat
=
None
self
.
remained_wav
=
None
self
.
offset
=
0
self
.
num_samples
=
0
self
.
device
=
None
self
.
hyps
=
[]
self
.
num_frames
=
0
self
.
chunk_num
=
0
self
.
global_frame_offset
=
0
self
.
result
=
[
]
self
.
result
_transcripts
=
[
''
]
def
decode
(
self
,
is_finished
=
False
):
if
"deepspeech2online"
in
self
.
model_type
:
...
...
@@ -187,7 +194,6 @@ class PaddleASRConnectionHanddler:
self
.
advance_decoding
(
is_finished
)
self
.
update_result
()
return
self
.
result_transcripts
[
0
]
except
Exception
as
e
:
logger
.
exception
(
e
)
else
:
...
...
@@ -203,16 +209,26 @@ class PaddleASRConnectionHanddler:
subsampling
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
context
=
self
.
model
.
encoder
.
embed
.
right_context
+
1
stride
=
subsampling
*
decoding_chunk_size
cached_feature_num
=
context
-
subsampling
# processed chunk feature cached for next chunk
# decoding window for model
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
if
self
.
cached_feat
is
None
:
logger
.
info
(
"no audio feat, please input more pcm data"
)
return
num_frames
=
self
.
cached_feat
.
shape
[
1
]
logger
.
info
(
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
)
# the cached feat must be larger decoding_window
if
num_frames
<
decoding_window
and
not
is_finished
:
logger
.
info
(
f
"frame feat num is less than
{
decoding_window
}
, please input more pcm data"
)
return
None
,
None
if
num_frames
<
context
:
logger
.
info
(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return
None
,
None
logger
.
info
(
"start to do model forward"
)
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
outputs
=
[]
...
...
@@ -242,14 +258,18 @@ class PaddleASRConnectionHanddler:
# update the offset
self
.
offset
+=
y
.
shape
[
1
]
logger
.
info
(
f
"output size:
{
len
(
outputs
)
}
"
)
ys
=
paddle
.
cat
(
outputs
,
1
)
masks
=
paddle
.
ones
([
1
,
ys
.
shape
[
1
]],
dtype
=
paddle
.
bool
)
masks
=
masks
.
unsqueeze
(
1
)
if
self
.
encoder_out
is
None
:
self
.
encoder_out
=
ys
else
:
self
.
encoder_out
=
paddle
.
concat
([
self
.
encoder_out
,
ys
],
axis
=
1
)
# masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
# masks = masks.unsqueeze(1)
# get the ctc probs
ctc_probs
=
self
.
model
.
ctc
.
log_softmax
(
ys
)
# (1, maxlen, vocab_size)
ctc_probs
=
ctc_probs
.
squeeze
(
0
)
# self.searcher.search(xs, ctc_probs, xs.place)
self
.
searcher
.
search
(
None
,
ctc_probs
,
self
.
cached_feat
.
place
)
...
...
@@ -260,7 +280,8 @@ class PaddleASRConnectionHanddler:
self
.
cached_feat
=
None
else
:
assert
self
.
cached_feat
.
shape
[
0
]
==
1
self
.
cached_feat
=
self
.
cached_feat
[
0
,
end
:,:].
unsqueeze
(
0
)
assert
end
>=
cached_feature_num
self
.
cached_feat
=
self
.
cached_feat
[
0
,
end
-
cached_feature_num
:,:].
unsqueeze
(
0
)
assert
len
(
self
.
cached_feat
.
shape
)
==
3
,
f
"current cache feat shape is:
{
self
.
cached_feat
.
shape
}
"
# ys for rescoring
...
...
@@ -274,9 +295,75 @@ class PaddleASRConnectionHanddler:
]
self
.
result_tokenids
=
[
hyp
for
hyp
in
hyps
]
def get_result(self):
    """Return the current best transcript held by this connection handler.

    Returns:
        str: the first entry of ``self.result_transcripts``, or an empty
            string when that list has no entries.
    """
    transcripts = self.result_transcripts
    # Guard clause: nothing decoded yet, so report an empty transcript.
    if len(transcripts) == 0:
        return ''
    return transcripts[0]
def rescoring(self):
    """Re-rank the CTC prefix-beam hypotheses with the attention decoder.

    Does nothing unless ``self.ctc_decode_config.decoding_method`` is
    ``"attention_rescoring"``. Otherwise the searcher is finalized, every
    hypothesis it holds is scored by ``self.model.decoder`` against the
    accumulated ``self.encoder_out``, the CTC score weighted by
    ``ctc_decode_config.ctc_weight`` is added, and the best hypothesis is
    stored in ``self.hyps`` before ``update_result()`` refreshes the
    transcript.

    Returns:
        None. The outcome is observable through ``self.hyps`` and whatever
        ``update_result()`` updates.
    """
    logger.info("rescoring the final result")
    if "attention_rescoring" != self.ctc_decode_config.decoding_method:
        return

    self.searcher.finalize_search()
    self.update_result()

    hyps = self.searcher.get_hyps()
    if hyps is None or len(hyps) == 0:
        return
    if self.encoder_out is None:
        # No model forward has produced encoder frames; nothing to rescore.
        logger.info("no encoder output is cached, skip the rescoring")
        return

    # The searcher may hold fewer hypotheses than the configured beam size,
    # so every batch dimension below is derived from the actual count.
    num_hyps = len(hyps)

    hyp_list = []
    for hyp in hyps:
        hyp_content = hyp[0]
        # An empty hypothesis cannot be turned into a tensor row; substitute
        # a single blank token so padding still works.
        if len(hyp_content) == 0:
            hyp_content = (self.model.ctc.blank_id, )
        hyp_content = paddle.to_tensor(
            hyp_content, place=self.device, dtype=paddle.long)
        hyp_list.append(hyp_content)
    hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
    hyps_lens = paddle.to_tensor(
        [len(hyp[0]) for hyp in hyps], place=self.device,
        dtype=paddle.long)  # (num_hyps,)
    hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
                              self.model.ignore_id)
    hyps_lens = hyps_lens + 1  # account for the <sos> added at the beginning

    # Tile the single-utterance encoder output once per hypothesis.
    encoder_out = self.encoder_out.repeat(num_hyps, 1, 1)
    encoder_mask = paddle.ones(
        (num_hyps, 1, encoder_out.shape[1]), dtype=paddle.bool)
    decoder_out, _ = self.model.decoder(
        encoder_out, encoder_mask, hyps_pad,
        hyps_lens)  # (num_hyps, max_hyps_len, vocab_size)
    # decoder scores in the ln domain
    decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
    decoder_out = decoder_out.numpy()

    best_score = -float('inf')
    best_index = 0
    # hyps is List[(Text=List[int], Score=float)]
    for i, hyp in enumerate(hyps):
        score = 0.0
        for j, w in enumerate(hyp[0]):
            score += decoder_out[i][j][w]
        # the decoder output following the last input token should be `eos`
        score += decoder_out[i][len(hyp[0])][self.model.eos]
        # add the ctc score (also in the ln domain)
        score += hyp[1] * self.ctc_decode_config.ctc_weight
        if score > best_score:
            best_score = score
            best_index = i

    # keep only the best hypothesis and refresh the published result
    logger.info(f"best index: {best_index}")
    self.hyps = [hyps[best_index][0]]
    self.update_result()
...
...
@@ -552,7 +639,7 @@ class ASRServerExecutor(ASRExecutor):
subsampling
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
context
=
self
.
model
.
encoder
.
embed
.
right_context
+
1
stride
=
subsampling
*
decoding_chunk_size
# decoding window for model
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
num_frames
=
xs
.
shape
[
1
]
...
...
paddlespeech/server/engine/asr/online/ctc_search.py
浏览文件 @
05a8a4b5
...
...
@@ -110,6 +110,11 @@ class CTCPrefixBeamSearch:
return
[
self
.
hyps
[
0
][
0
]]
def get_hyps(self):
    """Expose the hypotheses currently stored on the searcher.

    Returns:
        list: the object held in ``self.hyps``, handed back as-is without
            copying.
    """
    current_hyps = self.hyps
    return current_hyps
def
reset
(
self
):
...
...
@@ -117,3 +122,8 @@ class CTCPrefixBeamSearch:
"""
self
.
cur_hyps
=
None
self
.
hyps
=
None
def finalize_search(self):
    """No-op hook: ctc_prefix_beam_search has nothing to finalize.

    Returns:
        None
    """
    return None
paddlespeech/server/ws/asr_socket.py
浏览文件 @
05a8a4b5
...
...
@@ -13,16 +13,15 @@
# limitations under the License.
import
json
import
numpy
as
np
from
fastapi
import
APIRouter
from
fastapi
import
WebSocket
from
fastapi
import
WebSocketDisconnect
from
starlette.websockets
import
WebSocketState
as
WebSocketState
from
paddlespeech.server.engine.asr.online.asr_engine
import
PaddleASRConnectionHanddler
from
paddlespeech.server.engine.engine_pool
import
get_engine_pool
from
paddlespeech.server.utils.buffer
import
ChunkBuffer
from
paddlespeech.server.utils.vad
import
VADAudio
from
paddlespeech.server.engine.asr.online.asr_engine
import
PaddleASRConnectionHanddler
router
=
APIRouter
()
...
...
@@ -73,13 +72,17 @@ async def websocket_endpoint(websocket: WebSocket):
connection_handler
=
PaddleASRConnectionHanddler
(
asr_engine
)
await
websocket
.
send_json
(
resp
)
elif
message
[
'signal'
]
==
'end'
:
engine_pool
=
get_engine_pool
()
asr_engine
=
engine_pool
[
'asr'
]
# reset single engine for an new connection
asr_results
=
connection_handler
.
decode
(
is_finished
=
True
)
connection_handler
.
decode
(
is_finished
=
True
)
connection_handler
.
rescoring
()
asr_results
=
connection_handler
.
get_result
()
connection_handler
.
reset
()
asr_engine
.
reset
()
resp
=
{
"status"
:
"ok"
,
"signal"
:
"finished"
,
'asr_results'
:
asr_results
}
resp
=
{
"status"
:
"ok"
,
"signal"
:
"finished"
,
'asr_results'
:
asr_results
}
await
websocket
.
send_json
(
resp
)
break
else
:
...
...
@@ -87,25 +90,11 @@ async def websocket_endpoint(websocket: WebSocket):
await
websocket
.
send_json
(
resp
)
elif
"bytes"
in
message
:
message
=
message
[
"bytes"
]
engine_pool
=
get_engine_pool
()
asr_engine
=
engine_pool
[
'asr'
]
asr_results
=
""
connection_handler
.
extract_feat
(
message
)
asr_results
=
connection_handler
.
decode
(
is_finished
=
False
)
# connection_handler.
# frames = chunk_buffer.frame_generator(message)
# for frame in frames:
# # get the pcm data from the bytes
# samples = np.frombuffer(frame.bytes, dtype=np.int16)
# sample_rate = asr_engine.config.sample_rate
# x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
# sample_rate)
# asr_engine.run(x_chunk, x_chunk_lens)
# asr_results = asr_engine.postprocess()
connection_handler
.
decode
(
is_finished
=
False
)
asr_results
=
connection_handler
.
get_result
()
# # connection accept the sample data frame by frame
# asr_results = asr_engine.postprocess()
resp
=
{
'asr_results'
:
asr_results
}
print
(
"
\n
"
)
await
websocket
.
send_json
(
resp
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录