Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
69a6da4c
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
“97b7c913dad91c45aa074d78f35cec6fb68efea1”上不存在“mobile/src/fpga/V1/bias_scale.cpp”
提交
69a6da4c
编写于
6月 07, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ctc endpoint work
上级
8f9b7bba
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
98 addition
and
22 deletion
+98
-22
demos/streaming_asr_server/conf/application.yaml
demos/streaming_asr_server/conf/application.yaml
+2
-0
demos/streaming_asr_server/conf/ws_conformer_application.yaml
...s/streaming_asr_server/conf/ws_conformer_application.yaml
+3
-0
demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml
...asr_server/conf/ws_conformer_wenetspeech_application.yaml
+2
-0
paddlespeech/server/conf/ws_application.yaml
paddlespeech/server/conf/ws_application.yaml
+1
-0
paddlespeech/server/conf/ws_conformer_application.yaml
paddlespeech/server/conf/ws_conformer_application.yaml
+2
-0
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+46
-15
paddlespeech/server/engine/asr/online/ctc_endpoint.py
paddlespeech/server/engine/asr/online/ctc_endpoint.py
+14
-3
paddlespeech/server/ws/asr_api.py
paddlespeech/server/ws/asr_api.py
+28
-4
未找到文件。
demos/streaming_asr_server/conf/application.yaml
浏览文件 @
69a6da4c
...
@@ -31,6 +31,8 @@ asr_online:
...
@@ -31,6 +31,8 @@ asr_online:
force_yes
:
True
force_yes
:
True
device
:
'
cpu'
# cpu or gpu:id
device
:
'
cpu'
# cpu or gpu:id
decode_method
:
"
attention_rescoring"
decode_method
:
"
attention_rescoring"
continuous_decoding
:
True
# enable continue decoding when endpoint detected
am_predictor_conf
:
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
switch_ir_optim
:
True
...
...
demos/streaming_asr_server/conf/ws_conformer_application.yaml
浏览文件 @
69a6da4c
...
@@ -30,6 +30,9 @@ asr_online:
...
@@ -30,6 +30,9 @@ asr_online:
decode_method
:
decode_method
:
force_yes
:
True
force_yes
:
True
device
:
'
cpu'
# cpu or gpu:id
device
:
'
cpu'
# cpu or gpu:id
decode_method
:
"
attention_rescoring"
continuous_decoding
:
True
# enable continue decoding when endpoint detected
am_predictor_conf
:
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
switch_ir_optim
:
True
...
...
demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml
浏览文件 @
69a6da4c
...
@@ -31,6 +31,8 @@ asr_online:
...
@@ -31,6 +31,8 @@ asr_online:
force_yes
:
True
force_yes
:
True
device
:
'
cpu'
# cpu or gpu:id
device
:
'
cpu'
# cpu or gpu:id
decode_method
:
"
attention_rescoring"
decode_method
:
"
attention_rescoring"
continuous_decoding
:
True
# enable continue decoding when endpoint detected
am_predictor_conf
:
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
switch_ir_optim
:
True
...
...
paddlespeech/server/conf/ws_application.yaml
浏览文件 @
69a6da4c
...
@@ -29,6 +29,7 @@ asr_online:
...
@@ -29,6 +29,7 @@ asr_online:
cfg_path
:
cfg_path
:
decode_method
:
decode_method
:
force_yes
:
True
force_yes
:
True
device
:
# cpu or gpu:id
am_predictor_conf
:
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
device
:
# set 'gpu:id' or 'cpu'
...
...
paddlespeech/server/conf/ws_conformer_application.yaml
浏览文件 @
69a6da4c
...
@@ -30,6 +30,8 @@ asr_online:
...
@@ -30,6 +30,8 @@ asr_online:
decode_method
:
decode_method
:
force_yes
:
True
force_yes
:
True
device
:
# cpu or gpu:id
device
:
# cpu or gpu:id
continuous_decoding
:
True
# enable continue decoding when endpoint detected
am_predictor_conf
:
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
switch_ir_optim
:
True
...
...
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
69a6da4c
...
@@ -76,11 +76,13 @@ class PaddleASRConnectionHanddler:
...
@@ -76,11 +76,13 @@ class PaddleASRConnectionHanddler:
self
.
frame_shift_in_ms
=
int
(
self
.
frame_shift_in_ms
=
int
(
self
.
n_shift
/
self
.
preprocess_conf
.
process
[
0
][
'fs'
]
*
1000
)
self
.
n_shift
/
self
.
preprocess_conf
.
process
[
0
][
'fs'
]
*
1000
)
self
.
continuous_decoding
=
self
.
config
.
get
(
"continuous_decoding"
,
False
)
self
.
init_decoder
()
self
.
init_decoder
()
self
.
reset
()
self
.
reset
()
def
init_decoder
(
self
):
def
init_decoder
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
if
"deepspeech2"
in
self
.
model_type
:
assert
self
.
continuous_decoding
is
False
,
"ds2 model not support endpoint"
self
.
am_predictor
=
self
.
asr_engine
.
executor
.
am_predictor
self
.
am_predictor
=
self
.
asr_engine
.
executor
.
am_predictor
self
.
decoder
=
CTCDecoder
(
self
.
decoder
=
CTCDecoder
(
...
@@ -104,6 +106,8 @@ class PaddleASRConnectionHanddler:
...
@@ -104,6 +106,8 @@ class PaddleASRConnectionHanddler:
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
# acoustic model
# acoustic model
self
.
model
=
self
.
asr_engine
.
executor
.
model
self
.
model
=
self
.
asr_engine
.
executor
.
model
self
.
continuous_decoding
=
self
.
config
.
continuous_decoding
logger
.
info
(
f
"continue decoding:
{
self
.
continuous_decoding
}
"
)
# ctc decoding config
# ctc decoding config
self
.
ctc_decode_config
=
self
.
asr_engine
.
executor
.
config
.
decode
self
.
ctc_decode_config
=
self
.
asr_engine
.
executor
.
config
.
decode
...
@@ -120,7 +124,8 @@ class PaddleASRConnectionHanddler:
...
@@ -120,7 +124,8 @@ class PaddleASRConnectionHanddler:
if
"deepspeech2"
in
self
.
model_type
:
if
"deepspeech2"
in
self
.
model_type
:
return
return
# feature cache
# cache for audio and feat
self
.
remained_wav
=
None
self
.
cached_feat
=
None
self
.
cached_feat
=
None
## conformer
## conformer
...
@@ -135,6 +140,19 @@ class PaddleASRConnectionHanddler:
...
@@ -135,6 +140,19 @@ class PaddleASRConnectionHanddler:
## just for record info
## just for record info
self
.
chunk_num
=
0
# global decoding chunk num, not used
self
.
chunk_num
=
0
# global decoding chunk num, not used
def
output_reset
(
self
):
## outputs
# partial/ending decoding results
self
.
result_transcripts
=
[
''
]
# token timestamp result
self
.
word_time_stamp
=
[]
## just for record
self
.
hyps
=
[]
# one best timestamp viterbi prob is large.
self
.
time_stamp
=
[]
def
reset_continuous_decoding
(
self
):
def
reset_continuous_decoding
(
self
):
"""
"""
when in continous decoding, reset for next utterance.
when in continous decoding, reset for next utterance.
...
@@ -143,6 +161,7 @@ class PaddleASRConnectionHanddler:
...
@@ -143,6 +161,7 @@ class PaddleASRConnectionHanddler:
self
.
model_reset
()
self
.
model_reset
()
self
.
searcher
.
reset
()
self
.
searcher
.
reset
()
self
.
endpointer
.
reset
()
self
.
endpointer
.
reset
()
self
.
output_reset
()
def
reset
(
self
):
def
reset
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
if
"deepspeech2"
in
self
.
model_type
:
...
@@ -171,24 +190,14 @@ class PaddleASRConnectionHanddler:
...
@@ -171,24 +190,14 @@ class PaddleASRConnectionHanddler:
# frame step of cur utterance
# frame step of cur utterance
self
.
num_frames
=
0
self
.
num_frames
=
0
# cache for audio and feat
## endpoint
self
.
remained_wav
=
None
self
.
endpoint_state
=
False
# True for detect endpoint
self
.
cached_feat
=
None
## conformer
## conformer
self
.
model_reset
()
self
.
model_reset
()
## outputs
## outputs
# partial/ending decoding results
self
.
output_reset
()
self
.
result_transcripts
=
[
''
]
# token timestamp result
self
.
word_time_stamp
=
[]
## just for record
self
.
hyps
=
[]
# one best timestamp viterbi prob is large.
self
.
time_stamp
=
[]
def
extract_feat
(
self
,
samples
:
ByteString
):
def
extract_feat
(
self
,
samples
:
ByteString
):
logger
.
info
(
"Online ASR extract the feat"
)
logger
.
info
(
"Online ASR extract the feat"
)
...
@@ -388,6 +397,9 @@ class PaddleASRConnectionHanddler:
...
@@ -388,6 +397,9 @@ class PaddleASRConnectionHanddler:
if
"deepspeech"
in
self
.
model_type
:
if
"deepspeech"
in
self
.
model_type
:
return
return
# reset endpiont state
self
.
endpoint_state
=
False
logger
.
info
(
logger
.
info
(
"Conformer/Transformer: start to decode with advanced_decoding method"
"Conformer/Transformer: start to decode with advanced_decoding method"
)
)
...
@@ -489,6 +501,16 @@ class PaddleASRConnectionHanddler:
...
@@ -489,6 +501,16 @@ class PaddleASRConnectionHanddler:
# get one best hyps
# get one best hyps
self
.
hyps
=
self
.
searcher
.
get_one_best_hyps
()
self
.
hyps
=
self
.
searcher
.
get_one_best_hyps
()
# endpoint
if
not
is_finished
:
def
contain_nonsilence
():
return
len
(
self
.
hyps
)
>
0
and
len
(
self
.
hyps
[
0
])
>
0
decoding_something
=
contain_nonsilence
()
if
self
.
endpointer
.
endpoint_detected
(
ctc_probs
.
numpy
(),
decoding_something
):
self
.
endpoint_state
=
True
logger
.
info
(
f
"Endpoint is detected at
{
self
.
num_frames
}
frame."
)
# advance cache of feat
# advance cache of feat
assert
self
.
cached_feat
.
shape
[
0
]
==
1
#(B=1,T,D)
assert
self
.
cached_feat
.
shape
[
0
]
==
1
#(B=1,T,D)
assert
end
>=
cached_feature_num
assert
end
>=
cached_feature_num
...
@@ -847,6 +869,15 @@ class ASREngine(BaseEngine):
...
@@ -847,6 +869,15 @@ class ASREngine(BaseEngine):
logger
.
info
(
"Initialize ASR server engine successfully."
)
logger
.
info
(
"Initialize ASR server engine successfully."
)
return
True
return
True
def
new_handler
(
self
):
"""New handler from model.
Returns:
PaddleASRConnectionHanddler: asr handler instance
"""
return
PaddleASRConnectionHanddler
(
self
)
def
preprocess
(
self
,
*
args
,
**
kwargs
):
def
preprocess
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
(
"Online not using this."
)
raise
NotImplementedError
(
"Online not using this."
)
...
...
paddlespeech/server/engine/asr/online/ctc_endpoint.py
浏览文件 @
69a6da4c
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
from
typing
import
List
import
numpy
as
np
from
paddlespeech.cli.log
import
logger
from
paddlespeech.cli.log
import
logger
...
@@ -76,14 +77,23 @@ class OnlineCTCEndpoint:
...
@@ -76,14 +77,23 @@ class OnlineCTCEndpoint:
)
and
trailine_silence
>=
rule
.
min_trailing_silence
and
utterance_length
>=
rule
.
min_utterance_length
)
and
trailine_silence
>=
rule
.
min_trailing_silence
and
utterance_length
>=
rule
.
min_utterance_length
if
(
ans
):
if
(
ans
):
logger
.
info
(
logger
.
info
(
f
"Endpoint Rule:
{
rule_name
}
activated:
{
decoding_something
}
,
{
trailine_silence
}
,
{
utterance_length
}
"
f
"Endpoint Rule:
{
rule_name
}
activated:
{
rule
}
"
)
)
return
ans
return
ans
def
endpoint_detected
(
ctc_log_probs
:
List
[
List
[
float
]]
,
def
endpoint_detected
(
self
,
ctc_log_probs
:
np
.
ndarray
,
decoding_something
:
bool
)
->
bool
:
decoding_something
:
bool
)
->
bool
:
"""detect endpoint.
Args:
ctc_log_probs (np.ndarray): (T, D)
decoding_something (bool): contain nonsilince.
Returns:
bool: whether endpoint detected.
"""
for
logprob
in
ctc_log_probs
:
for
logprob
in
ctc_log_probs
:
blank_prob
=
exp
(
logprob
[
self
.
opts
.
blank_id
])
blank_prob
=
np
.
exp
(
logprob
[
self
.
opts
.
blank
])
self
.
num_frames_decoded
+=
1
self
.
num_frames_decoded
+=
1
if
blank_prob
>
self
.
opts
.
blank_threshold
:
if
blank_prob
>
self
.
opts
.
blank_threshold
:
...
@@ -96,6 +106,7 @@ class OnlineCTCEndpoint:
...
@@ -96,6 +106,7 @@ class OnlineCTCEndpoint:
utterance_length
=
self
.
num_frames_decoded
*
self
.
frame_shift_in_ms
utterance_length
=
self
.
num_frames_decoded
*
self
.
frame_shift_in_ms
trailing_silence
=
self
.
trailing_silence_frames
*
self
.
frame_shift_in_ms
trailing_silence
=
self
.
trailing_silence_frames
*
self
.
frame_shift_in_ms
if
self
.
rule_activated
(
self
.
opts
.
rule1
,
'rule1'
,
decoding_something
,
if
self
.
rule_activated
(
self
.
opts
.
rule1
,
'rule1'
,
decoding_something
,
trailing_silence
,
utterance_length
):
trailing_silence
,
utterance_length
):
return
True
return
True
...
...
paddlespeech/server/ws/asr_api.py
浏览文件 @
69a6da4c
...
@@ -38,7 +38,7 @@ async def websocket_endpoint(websocket: WebSocket):
...
@@ -38,7 +38,7 @@ async def websocket_endpoint(websocket: WebSocket):
#2. if we accept the websocket headers, we will get the online asr engine instance
#2. if we accept the websocket headers, we will get the online asr engine instance
engine_pool
=
get_engine_pool
()
engine_pool
=
get_engine_pool
()
asr_
engine
=
engine_pool
[
'asr'
]
asr_
model
=
engine_pool
[
'asr'
]
#3. each websocket connection, we will create an PaddleASRConnectionHanddler to process such audio
#3. each websocket connection, we will create an PaddleASRConnectionHanddler to process such audio
# and each connection has its own connection instance to process the request
# and each connection has its own connection instance to process the request
...
@@ -70,7 +70,8 @@ async def websocket_endpoint(websocket: WebSocket):
...
@@ -70,7 +70,8 @@ async def websocket_endpoint(websocket: WebSocket):
resp
=
{
"status"
:
"ok"
,
"signal"
:
"server_ready"
}
resp
=
{
"status"
:
"ok"
,
"signal"
:
"server_ready"
}
# do something at begining here
# do something at begining here
# create the instance to process the audio
# create the instance to process the audio
connection_handler
=
PaddleASRConnectionHanddler
(
asr_engine
)
#connection_handler = PaddleASRConnectionHanddler(asr_model)
connection_handler
=
asr_model
.
new_handler
()
await
websocket
.
send_json
(
resp
)
await
websocket
.
send_json
(
resp
)
elif
message
[
'signal'
]
==
'end'
:
elif
message
[
'signal'
]
==
'end'
:
# reset single engine for an new connection
# reset single engine for an new connection
...
@@ -100,11 +101,34 @@ async def websocket_endpoint(websocket: WebSocket):
...
@@ -100,11 +101,34 @@ async def websocket_endpoint(websocket: WebSocket):
# and decode for the result in this package data
# and decode for the result in this package data
connection_handler
.
extract_feat
(
message
)
connection_handler
.
extract_feat
(
message
)
connection_handler
.
decode
(
is_finished
=
False
)
connection_handler
.
decode
(
is_finished
=
False
)
if
connection_handler
.
endpoint_state
:
logger
.
info
(
"endpoint: detected and rescoring."
)
connection_handler
.
rescoring
()
word_time_stamp
=
connection_handler
.
get_word_time_stamp
()
asr_results
=
connection_handler
.
get_result
()
asr_results
=
connection_handler
.
get_result
()
# return the current period result
if
connection_handler
.
endpoint_state
:
# if the engine create the vad instance, this connection will have many period results
if
connection_handler
.
continuous_decoding
:
logger
.
info
(
"endpoint: continue decoding"
)
connection_handler
.
reset_continuous_decoding
()
else
:
logger
.
info
(
"endpoint: exit decoding"
)
# ending by endpoint
resp
=
{
"status"
:
"ok"
,
"signal"
:
"finished"
,
'result'
:
asr_results
,
'times'
:
word_time_stamp
}
await
websocket
.
send_json
(
resp
)
break
# return the current partial result
# if the engine create the vad instance, this connection will have many partial results
resp
=
{
'result'
:
asr_results
}
resp
=
{
'result'
:
asr_results
}
await
websocket
.
send_json
(
resp
)
await
websocket
.
send_json
(
resp
)
except
WebSocketDisconnect
as
e
:
except
WebSocketDisconnect
as
e
:
logger
.
error
(
e
)
logger
.
error
(
e
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录