Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
69a6da4c
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
69a6da4c
编写于
6月 07, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ctc endpoint work
上级
8f9b7bba
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
98 addition
and
22 deletion
+98
-22
demos/streaming_asr_server/conf/application.yaml
demos/streaming_asr_server/conf/application.yaml
+2
-0
demos/streaming_asr_server/conf/ws_conformer_application.yaml
...s/streaming_asr_server/conf/ws_conformer_application.yaml
+3
-0
demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml
...asr_server/conf/ws_conformer_wenetspeech_application.yaml
+2
-0
paddlespeech/server/conf/ws_application.yaml
paddlespeech/server/conf/ws_application.yaml
+1
-0
paddlespeech/server/conf/ws_conformer_application.yaml
paddlespeech/server/conf/ws_conformer_application.yaml
+2
-0
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+46
-15
paddlespeech/server/engine/asr/online/ctc_endpoint.py
paddlespeech/server/engine/asr/online/ctc_endpoint.py
+14
-3
paddlespeech/server/ws/asr_api.py
paddlespeech/server/ws/asr_api.py
+28
-4
未找到文件。
demos/streaming_asr_server/conf/application.yaml
浏览文件 @
69a6da4c
...
...
@@ -31,6 +31,8 @@ asr_online:
force_yes
:
True
device
:
'
cpu'
# cpu or gpu:id
decode_method
:
"
attention_rescoring"
continuous_decoding
:
True
# enable continue decoding when endpoint detected
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
...
...
demos/streaming_asr_server/conf/ws_conformer_application.yaml
浏览文件 @
69a6da4c
...
...
@@ -30,6 +30,9 @@ asr_online:
decode_method
:
force_yes
:
True
device
:
'
cpu'
# cpu or gpu:id
decode_method
:
"
attention_rescoring"
continuous_decoding
:
True
# enable continue decoding when endpoint detected
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
...
...
demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml
浏览文件 @
69a6da4c
...
...
@@ -31,6 +31,8 @@ asr_online:
force_yes
:
True
device
:
'
cpu'
# cpu or gpu:id
decode_method
:
"
attention_rescoring"
continuous_decoding
:
True
# enable continue decoding when endpoint detected
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
...
...
paddlespeech/server/conf/ws_application.yaml
浏览文件 @
69a6da4c
...
...
@@ -29,6 +29,7 @@ asr_online:
cfg_path
:
decode_method
:
force_yes
:
True
device
:
# cpu or gpu:id
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
...
...
paddlespeech/server/conf/ws_conformer_application.yaml
浏览文件 @
69a6da4c
...
...
@@ -30,6 +30,8 @@ asr_online:
decode_method
:
force_yes
:
True
device
:
# cpu or gpu:id
continuous_decoding
:
True
# enable continue decoding when endpoint detected
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
...
...
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
69a6da4c
...
...
@@ -55,7 +55,7 @@ class PaddleASRConnectionHanddler:
self
.
config
=
asr_engine
.
config
# server config
self
.
model_config
=
asr_engine
.
executor
.
config
self
.
asr_engine
=
asr_engine
# model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
self
.
model_type
=
self
.
asr_engine
.
executor
.
model_type
self
.
sample_rate
=
self
.
asr_engine
.
executor
.
sample_rate
...
...
@@ -76,11 +76,13 @@ class PaddleASRConnectionHanddler:
self
.
frame_shift_in_ms
=
int
(
self
.
n_shift
/
self
.
preprocess_conf
.
process
[
0
][
'fs'
]
*
1000
)
self
.
continuous_decoding
=
self
.
config
.
get
(
"continuous_decoding"
,
False
)
self
.
init_decoder
()
self
.
reset
()
def
init_decoder
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
assert
self
.
continuous_decoding
is
False
,
"ds2 model not support endpoint"
self
.
am_predictor
=
self
.
asr_engine
.
executor
.
am_predictor
self
.
decoder
=
CTCDecoder
(
...
...
@@ -104,6 +106,8 @@ class PaddleASRConnectionHanddler:
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
# acoustic model
self
.
model
=
self
.
asr_engine
.
executor
.
model
self
.
continuous_decoding
=
self
.
config
.
continuous_decoding
logger
.
info
(
f
"continue decoding:
{
self
.
continuous_decoding
}
"
)
# ctc decoding config
self
.
ctc_decode_config
=
self
.
asr_engine
.
executor
.
config
.
decode
...
...
@@ -120,7 +124,8 @@ class PaddleASRConnectionHanddler:
if
"deepspeech2"
in
self
.
model_type
:
return
# feature cache
# cache for audio and feat
self
.
remained_wav
=
None
self
.
cached_feat
=
None
## conformer
...
...
@@ -135,6 +140,19 @@ class PaddleASRConnectionHanddler:
## just for record info
self
.
chunk_num
=
0
# global decoding chunk num, not used
def
output_reset
(
self
):
## outputs
# partial/ending decoding results
self
.
result_transcripts
=
[
''
]
# token timestamp result
self
.
word_time_stamp
=
[]
## just for record
self
.
hyps
=
[]
# one best timestamp viterbi prob is large.
self
.
time_stamp
=
[]
def
reset_continuous_decoding
(
self
):
"""
when in continous decoding, reset for next utterance.
...
...
@@ -143,6 +161,7 @@ class PaddleASRConnectionHanddler:
self
.
model_reset
()
self
.
searcher
.
reset
()
self
.
endpointer
.
reset
()
self
.
output_reset
()
def
reset
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
...
...
@@ -171,24 +190,14 @@ class PaddleASRConnectionHanddler:
# frame step of cur utterance
self
.
num_frames
=
0
# cache for audio and feat
self
.
remained_wav
=
None
self
.
cached_feat
=
None
## endpoint
self
.
endpoint_state
=
False
# True for detect endpoint
## conformer
self
.
model_reset
()
## outputs
# partial/ending decoding results
self
.
result_transcripts
=
[
''
]
# token timestamp result
self
.
word_time_stamp
=
[]
## just for record
self
.
hyps
=
[]
# one best timestamp viterbi prob is large.
self
.
time_stamp
=
[]
self
.
output_reset
()
def
extract_feat
(
self
,
samples
:
ByteString
):
logger
.
info
(
"Online ASR extract the feat"
)
...
...
@@ -388,6 +397,9 @@ class PaddleASRConnectionHanddler:
if
"deepspeech"
in
self
.
model_type
:
return
# reset endpiont state
self
.
endpoint_state
=
False
logger
.
info
(
"Conformer/Transformer: start to decode with advanced_decoding method"
)
...
...
@@ -489,6 +501,16 @@ class PaddleASRConnectionHanddler:
# get one best hyps
self
.
hyps
=
self
.
searcher
.
get_one_best_hyps
()
# endpoint
if
not
is_finished
:
def
contain_nonsilence
():
return
len
(
self
.
hyps
)
>
0
and
len
(
self
.
hyps
[
0
])
>
0
decoding_something
=
contain_nonsilence
()
if
self
.
endpointer
.
endpoint_detected
(
ctc_probs
.
numpy
(),
decoding_something
):
self
.
endpoint_state
=
True
logger
.
info
(
f
"Endpoint is detected at
{
self
.
num_frames
}
frame."
)
# advance cache of feat
assert
self
.
cached_feat
.
shape
[
0
]
==
1
#(B=1,T,D)
assert
end
>=
cached_feature_num
...
...
@@ -847,6 +869,15 @@ class ASREngine(BaseEngine):
logger
.
info
(
"Initialize ASR server engine successfully."
)
return
True
def
new_handler
(
self
):
"""New handler from model.
Returns:
PaddleASRConnectionHanddler: asr handler instance
"""
return
PaddleASRConnectionHanddler
(
self
)
def
preprocess
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
(
"Online not using this."
)
...
...
paddlespeech/server/engine/asr/online/ctc_endpoint.py
浏览文件 @
69a6da4c
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
from
dataclasses
import
dataclass
from
typing
import
List
import
numpy
as
np
from
paddlespeech.cli.log
import
logger
...
...
@@ -76,14 +77,23 @@ class OnlineCTCEndpoint:
)
and
trailine_silence
>=
rule
.
min_trailing_silence
and
utterance_length
>=
rule
.
min_utterance_length
if
(
ans
):
logger
.
info
(
f
"Endpoint Rule:
{
rule_name
}
activated:
{
decoding_something
}
,
{
trailine_silence
}
,
{
utterance_length
}
"
f
"Endpoint Rule:
{
rule_name
}
activated:
{
rule
}
"
)
return
ans
def
endpoint_detected
(
ctc_log_probs
:
List
[
List
[
float
]]
,
def
endpoint_detected
(
self
,
ctc_log_probs
:
np
.
ndarray
,
decoding_something
:
bool
)
->
bool
:
"""detect endpoint.
Args:
ctc_log_probs (np.ndarray): (T, D)
decoding_something (bool): contain nonsilince.
Returns:
bool: whether endpoint detected.
"""
for
logprob
in
ctc_log_probs
:
blank_prob
=
exp
(
logprob
[
self
.
opts
.
blank_id
])
blank_prob
=
np
.
exp
(
logprob
[
self
.
opts
.
blank
])
self
.
num_frames_decoded
+=
1
if
blank_prob
>
self
.
opts
.
blank_threshold
:
...
...
@@ -96,6 +106,7 @@ class OnlineCTCEndpoint:
utterance_length
=
self
.
num_frames_decoded
*
self
.
frame_shift_in_ms
trailing_silence
=
self
.
trailing_silence_frames
*
self
.
frame_shift_in_ms
if
self
.
rule_activated
(
self
.
opts
.
rule1
,
'rule1'
,
decoding_something
,
trailing_silence
,
utterance_length
):
return
True
...
...
paddlespeech/server/ws/asr_api.py
浏览文件 @
69a6da4c
...
...
@@ -38,7 +38,7 @@ async def websocket_endpoint(websocket: WebSocket):
#2. if we accept the websocket headers, we will get the online asr engine instance
engine_pool
=
get_engine_pool
()
asr_
engine
=
engine_pool
[
'asr'
]
asr_
model
=
engine_pool
[
'asr'
]
#3. each websocket connection, we will create an PaddleASRConnectionHanddler to process such audio
# and each connection has its own connection instance to process the request
...
...
@@ -70,7 +70,8 @@ async def websocket_endpoint(websocket: WebSocket):
resp
=
{
"status"
:
"ok"
,
"signal"
:
"server_ready"
}
# do something at begining here
# create the instance to process the audio
connection_handler
=
PaddleASRConnectionHanddler
(
asr_engine
)
#connection_handler = PaddleASRConnectionHanddler(asr_model)
connection_handler
=
asr_model
.
new_handler
()
await
websocket
.
send_json
(
resp
)
elif
message
[
'signal'
]
==
'end'
:
# reset single engine for an new connection
...
...
@@ -100,11 +101,34 @@ async def websocket_endpoint(websocket: WebSocket):
# and decode for the result in this package data
connection_handler
.
extract_feat
(
message
)
connection_handler
.
decode
(
is_finished
=
False
)
if
connection_handler
.
endpoint_state
:
logger
.
info
(
"endpoint: detected and rescoring."
)
connection_handler
.
rescoring
()
word_time_stamp
=
connection_handler
.
get_word_time_stamp
()
asr_results
=
connection_handler
.
get_result
()
# return the current period result
# if the engine create the vad instance, this connection will have many period results
if
connection_handler
.
endpoint_state
:
if
connection_handler
.
continuous_decoding
:
logger
.
info
(
"endpoint: continue decoding"
)
connection_handler
.
reset_continuous_decoding
()
else
:
logger
.
info
(
"endpoint: exit decoding"
)
# ending by endpoint
resp
=
{
"status"
:
"ok"
,
"signal"
:
"finished"
,
'result'
:
asr_results
,
'times'
:
word_time_stamp
}
await
websocket
.
send_json
(
resp
)
break
# return the current partial result
# if the engine create the vad instance, this connection will have many partial results
resp
=
{
'result'
:
asr_results
}
await
websocket
.
send_json
(
resp
)
except
WebSocketDisconnect
as
e
:
logger
.
error
(
e
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录