Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
cf9a590f
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
206
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
cf9a590f
编写于
4月 20, 2022
作者:
H
Hui Zhang
提交者:
GitHub
4月 20, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1704 from Honei/server
[asr][websocket] add asr conformer websocket server
上级
0d0aabe2
ac9fcf7f
变更
18
展开全部
隐藏空白更改
内联
并排
Showing
18 changed file
with
1154 addition
and
147 deletion
+1154
-147
paddlespeech/cli/asr/infer.py
paddlespeech/cli/asr/infer.py
+24
-17
paddlespeech/cli/asr/pretrained_models.py
paddlespeech/cli/asr/pretrained_models.py
+2
-0
paddlespeech/s2t/models/u2/u2.py
paddlespeech/s2t/models/u2/u2.py
+3
-4
paddlespeech/s2t/modules/ctc.py
paddlespeech/s2t/modules/ctc.py
+1
-1
paddlespeech/server/README.md
paddlespeech/server/README.md
+13
-0
paddlespeech/server/README_cn.md
paddlespeech/server/README_cn.md
+14
-0
paddlespeech/server/bin/paddlespeech_client.py
paddlespeech/server/bin/paddlespeech_client.py
+4
-2
paddlespeech/server/conf/ws_application.yaml
paddlespeech/server/conf/ws_application.yaml
+4
-8
paddlespeech/server/conf/ws_conformer_application.yaml
paddlespeech/server/conf/ws_conformer_application.yaml
+45
-0
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+821
-76
paddlespeech/server/engine/asr/online/ctc_search.py
paddlespeech/server/engine/asr/online/ctc_search.py
+128
-0
paddlespeech/server/tests/__init__.py
paddlespeech/server/tests/__init__.py
+13
-0
paddlespeech/server/tests/asr/__init__.py
paddlespeech/server/tests/asr/__init__.py
+13
-0
paddlespeech/server/tests/asr/offline/__init__.py
paddlespeech/server/tests/asr/offline/__init__.py
+13
-0
paddlespeech/server/tests/asr/online/__init__.py
paddlespeech/server/tests/asr/online/__init__.py
+13
-0
paddlespeech/server/tests/asr/online/websocket_client.py
paddlespeech/server/tests/asr/online/websocket_client.py
+11
-8
paddlespeech/server/utils/buffer.py
paddlespeech/server/utils/buffer.py
+1
-1
paddlespeech/server/ws/asr_socket.py
paddlespeech/server/ws/asr_socket.py
+31
-30
未找到文件。
paddlespeech/cli/asr/infer.py
浏览文件 @
cf9a590f
...
...
@@ -40,7 +40,6 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__
=
[
'ASRExecutor'
]
@
cli_register
(
name
=
'paddlespeech.asr'
,
description
=
'Speech to text infer command.'
)
class
ASRExecutor
(
BaseExecutor
):
...
...
@@ -125,6 +124,7 @@ class ASRExecutor(BaseExecutor):
"""
Init model and other resources from a specific path.
"""
logger
.
info
(
"start to init the model"
)
if
hasattr
(
self
,
'model'
):
logger
.
info
(
'Model had been initialized.'
)
return
...
...
@@ -140,14 +140,15 @@ class ASRExecutor(BaseExecutor):
res_path
,
self
.
pretrained_models
[
tag
][
'ckpt_path'
]
+
".pdparams"
)
logger
.
info
(
res_path
)
logger
.
info
(
self
.
cfg_path
)
logger
.
info
(
self
.
ckpt_path
)
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
ckpt_path
=
os
.
path
.
abspath
(
ckpt_path
+
".pdparams"
)
self
.
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
logger
.
info
(
self
.
cfg_path
)
logger
.
info
(
self
.
ckpt_path
)
#Init body.
self
.
config
=
CfgNode
(
new_allowed
=
True
)
self
.
config
.
merge_from_file
(
self
.
cfg_path
)
...
...
@@ -176,7 +177,6 @@ class ASRExecutor(BaseExecutor):
vocab
=
self
.
config
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
spm_model_prefix
)
self
.
config
.
decode
.
decoding_method
=
decode_method
else
:
raise
Exception
(
"wrong type"
)
model_name
=
model_type
[:
model_type
.
rindex
(
...
...
@@ -254,12 +254,14 @@ class ASRExecutor(BaseExecutor):
else
:
raise
Exception
(
"wrong type"
)
logger
.
info
(
"audio feat process success"
)
@
paddle
.
no_grad
()
def
infer
(
self
,
model_type
:
str
):
"""
Model inference and result stored in self.output.
"""
logger
.
info
(
"start to infer the model to get the output"
)
cfg
=
self
.
config
.
decode
audio
=
self
.
_inputs
[
"audio"
]
audio_len
=
self
.
_inputs
[
"audio_len"
]
...
...
@@ -276,17 +278,22 @@ class ASRExecutor(BaseExecutor):
self
.
_outputs
[
"result"
]
=
result_transcripts
[
0
]
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
result_transcripts
=
self
.
model
.
decode
(
audio
,
audio_len
,
text_feature
=
self
.
text_feature
,
decoding_method
=
cfg
.
decoding_method
,
beam_size
=
cfg
.
beam_size
,
ctc_weight
=
cfg
.
ctc_weight
,
decoding_chunk_size
=
cfg
.
decoding_chunk_size
,
num_decoding_left_chunks
=
cfg
.
num_decoding_left_chunks
,
simulate_streaming
=
cfg
.
simulate_streaming
)
self
.
_outputs
[
"result"
]
=
result_transcripts
[
0
][
0
]
logger
.
info
(
f
"we will use the transformer like model :
{
model_type
}
"
)
try
:
result_transcripts
=
self
.
model
.
decode
(
audio
,
audio_len
,
text_feature
=
self
.
text_feature
,
decoding_method
=
cfg
.
decoding_method
,
beam_size
=
cfg
.
beam_size
,
ctc_weight
=
cfg
.
ctc_weight
,
decoding_chunk_size
=
cfg
.
decoding_chunk_size
,
num_decoding_left_chunks
=
cfg
.
num_decoding_left_chunks
,
simulate_streaming
=
cfg
.
simulate_streaming
)
self
.
_outputs
[
"result"
]
=
result_transcripts
[
0
][
0
]
except
Exception
as
e
:
logger
.
exception
(
e
)
else
:
raise
Exception
(
"invalid model name"
)
...
...
paddlespeech/cli/asr/pretrained_models.py
浏览文件 @
cf9a590f
...
...
@@ -88,6 +88,8 @@ model_alias = {
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline"
,
"conformer"
:
"paddlespeech.s2t.models.u2:U2Model"
,
"conformer_online"
:
"paddlespeech.s2t.models.u2:U2Model"
,
"transformer"
:
"paddlespeech.s2t.models.u2:U2Model"
,
"wenetspeech"
:
...
...
paddlespeech/s2t/models/u2/u2.py
浏览文件 @
cf9a590f
...
...
@@ -279,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if
end_flag
.
cast
(
paddle
.
int64
).
sum
()
==
running_size
:
break
# 2.1 Forward decoder step
hyps_mask
=
subsequent_mask
(
i
).
unsqueeze
(
0
).
repeat
(
running_size
,
1
,
1
).
to
(
device
)
# (B*N, i, i)
# logp: (B*N, vocab)
logp
,
cache
=
self
.
decoder
.
forward_one_step
(
encoder_out
,
encoder_mask
,
hyps
,
hyps_mask
,
cache
)
# 2.2 First beam prune: select topk best prob at current time
top_k_logp
,
top_k_index
=
logp
.
topk
(
beam_size
)
# (B*N, N)
top_k_logp
=
mask_finished_scores
(
top_k_logp
,
end_flag
)
...
...
@@ -708,11 +707,11 @@ class U2BaseModel(ASRInterface, nn.Layer):
batch_size
=
feats
.
shape
[
0
]
if
decoding_method
in
[
'ctc_prefix_beam_search'
,
'attention_rescoring'
]
and
batch_size
>
1
:
logger
.
fatal
(
logger
.
error
(
f
'decoding mode
{
decoding_method
}
must be running with batch_size == 1'
)
logger
.
error
(
f
"current batch_size is
{
batch_size
}
"
)
sys
.
exit
(
1
)
if
decoding_method
==
'attention'
:
hyps
=
self
.
recognize
(
feats
,
...
...
paddlespeech/s2t/modules/ctc.py
浏览文件 @
cf9a590f
...
...
@@ -180,7 +180,7 @@ class CTCDecoder(CTCDecoderBase):
# init once
if
self
.
_ext_scorer
is
not
None
:
return
if
language_model_path
!=
''
:
logger
.
info
(
"begin to initialize the external scorer "
"for decoding"
)
...
...
paddlespeech/server/README.md
浏览文件 @
cf9a590f
...
...
@@ -35,3 +35,16 @@
```
bash
paddlespeech_client cls
--server_ip
127.0.0.1
--port
8090
--input
input.wav
```
## Online ASR Server
### Launch online asr server
```
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### Access online asr server
```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
```
\ No newline at end of file
paddlespeech/server/README_cn.md
浏览文件 @
cf9a590f
...
...
@@ -35,3 +35,17 @@
```
bash
paddlespeech_client cls
--server_ip
127.0.0.1
--port
8090
--input
input.wav
```
## 流式ASR
### 启动流式语音识别服务
```
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### 访问流式语音识别服务
```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav
```
\ No newline at end of file
paddlespeech/server/bin/paddlespeech_client.py
浏览文件 @
cf9a590f
...
...
@@ -277,11 +277,12 @@ class ASRClientExecutor(BaseExecutor):
lang
=
lang
,
audio_format
=
audio_format
)
time_end
=
time
.
time
()
logger
.
info
(
res
.
json
()
)
logger
.
info
(
res
)
logger
.
info
(
"Response time %f s."
%
(
time_end
-
time_start
))
return
True
except
Exception
as
e
:
logger
.
error
(
"Failed to speech recognition."
)
logger
.
error
(
e
)
return
False
@
stats_wrapper
...
...
@@ -299,9 +300,10 @@ class ASRClientExecutor(BaseExecutor):
logging
.
info
(
"asr websocket client start"
)
handler
=
ASRAudioHandler
(
server_ip
,
port
)
loop
=
asyncio
.
get_event_loop
()
loop
.
run_until_complete
(
handler
.
run
(
input
))
res
=
loop
.
run_until_complete
(
handler
.
run
(
input
))
logging
.
info
(
"asr websocket client finished"
)
return
res
[
'asr_results'
]
@
cli_client_register
(
name
=
'paddlespeech_client.cls'
,
description
=
'visit cls service'
)
...
...
paddlespeech/server/conf/ws_application.yaml
浏览文件 @
cf9a590f
...
...
@@ -41,11 +41,7 @@ asr_online:
shift_ms
:
40
sample_rate
:
16000
sample_width
:
2
vad_conf
:
aggressiveness
:
2
sample_rate
:
16000
frame_duration_ms
:
20
sample_width
:
2
padding_ms
:
200
padding_ratio
:
0.9
window_n
:
7
# frame
shift_n
:
4
# frame
window_ms
:
20
# ms
shift_ms
:
10
# ms
paddlespeech/server/conf/ws_conformer_application.yaml
0 → 100644
浏览文件 @
cf9a590f
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host
:
0.0.0.0
port
:
8090
# The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only support online engine type.
protocol
:
'
websocket'
engine_list
:
[
'
asr_online'
]
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online
:
model_type
:
'
conformer_online_multicn'
am_model
:
# the pdmodel file of am static model [optional]
am_params
:
# the pdiparams file of am static model [optional]
lang
:
'
zh'
sample_rate
:
16000
cfg_path
:
decode_method
:
force_yes
:
True
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
glog_info
:
False
# True -> print glog
summary
:
True
# False -> do not show predictor config
chunk_buffer_conf
:
window_n
:
7
# frame
shift_n
:
4
# frame
window_ms
:
25
# ms
shift_ms
:
10
# ms
sample_rate
:
16000
sample_width
:
2
\ No newline at end of file
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
cf9a590f
此差异已折叠。
点击以展开。
paddlespeech/server/engine/asr/online/ctc_search.py
0 → 100644
浏览文件 @
cf9a590f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
defaultdict
import
paddle
from
paddlespeech.cli.log
import
logger
from
paddlespeech.s2t.utils.utility
import
log_add
__all__
=
[
'CTCPrefixBeamSearch'
]
class CTCPrefixBeamSearch:
    """CTC prefix beam search that can be fed one chunk at a time.

    The current beam (``cur_hyps``) is stored on the instance and is only
    initialized when it is ``None``, so consecutive calls to :meth:`search`
    continue from the beam left by the previous call. Call :meth:`reset`
    to start a fresh search.
    """

    def __init__(self, config):
        """Store the decoding config and clear the search state.

        Args:
            config (yacs.config.CfgNode): decoding configuration. Only
                ``config.beam_size`` is read by this class.
        """
        self.config = config
        self.reset()

    @paddle.no_grad()
    def search(self, ctc_probs, device, blank_id=0):
        """Run CTC prefix beam search over one chunk of CTC output.

        Args:
            ctc_probs (paddle.Tensor): 2-D tensor indexed as
                ``[frame, token]``. The values are added to log-domain
                scores, so they are expected to be log-probabilities.
            device (paddle.fluid.core_avx.Place): the feature host device,
                such as CUDAPlace(0). Accepted for interface compatibility;
                it is not used by the search itself.
            blank_id (int, optional): the blank id in the vocab. Defaults to 0.

        Returns:
            list: ``(prefix, score)`` pairs for the current beam, best first,
                where ``prefix`` is a tuple of token ids and ``score`` is
                ``log_add`` of the blank-ending and non-blank-ending scores.
        """
        # decode
        logger.info("start to ctc prefix search")
        beam_size = self.config.beam_size
        maxlen = ctc_probs.shape[0]
        assert len(ctc_probs.shape) == 2

        neg_inf = -float('inf')

        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
        # blank_ending_score and none_blank_ending_score in ln domain
        if self.cur_hyps is None:
            # Start from the empty prefix: log(1) for "ends in blank",
            # log(0) for "ends in non-blank".
            self.cur_hyps = [(tuple(), (0.0, neg_inf))]

        # 2. CTC beam search step by step
        for t in range(0, maxlen):
            logp = ctc_probs[t]  # (vocab_size,)
            # key: prefix, value (pb, pnb), default value(-inf, -inf)
            next_hyps = defaultdict(lambda: (neg_inf, neg_inf))

            # 2.1 First beam prune: select topk best
            #     do token passing process
            # Only the indices are needed; each score is read back from logp.
            _, top_k_index = logp.topk(beam_size)  # (beam_size,)
            for s in top_k_index:
                s = s.item()
                ps = logp[s].item()
                for prefix, (pb, pnb) in self.cur_hyps:
                    last = prefix[-1] if len(prefix) > 0 else None
                    if s == blank_id:
                        # blank: the prefix is unchanged and now ends in blank
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pb = log_add([n_pb, pb + ps, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                    elif s == last:
                        # Update *ss -> *s;
                        # a repeated token without a blank in between
                        # collapses into the same prefix.
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pnb = log_add([n_pnb, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                        # Update *s-s -> *ss, - is for blank
                        # the repeat only extends the prefix when the
                        # previous frame ended in blank.
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)
                    else:
                        # a different token always extends the prefix
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)

            # 2.2 Second beam prune
            next_hyps = sorted(
                next_hyps.items(),
                key=lambda x: log_add(list(x[1])),
                reverse=True)
            self.cur_hyps = next_hyps[:beam_size]

        # Merge the two ending scores of every surviving prefix.
        self.hyps = [(prefix, log_add([pb, pnb]))
                     for prefix, (pb, pnb) in self.cur_hyps]

        logger.info("ctc prefix search success")
        return self.hyps

    def get_one_best_hyps(self):
        """Return the one best result

        ``search`` must have been called since the last ``reset``;
        otherwise ``self.hyps`` is ``None`` and indexing it fails.

        Returns:
            list: a one-element list holding the best prefix (tuple of ids)
        """
        return [self.hyps[0][0]]

    def get_hyps(self):
        """Return the search hyps

        Returns:
            list: the ``(prefix, score)`` pairs produced by the last call to
                ``search``, or ``None`` if no search has run since ``reset``
        """
        return self.hyps

    def reset(self):
        """Reset the search cache value
        """
        self.cur_hyps = None
        self.hyps = None

    def finalize_search(self):
        """do nothing in ctc_prefix_beam_search
        """
        pass
paddlespeech/server/tests/__init__.py
0 → 100644
浏览文件 @
cf9a590f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/tests/asr/__init__.py
0 → 100644
浏览文件 @
cf9a590f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/tests/asr/offline/__init__.py
0 → 100644
浏览文件 @
cf9a590f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/tests/asr/online/__init__.py
0 → 100644
浏览文件 @
cf9a590f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/tests/asr/online/websocket_client.py
浏览文件 @
cf9a590f
...
...
@@ -34,10 +34,9 @@ class ASRAudioHandler:
def
read_wave
(
self
,
wavfile_path
:
str
):
samples
,
sample_rate
=
soundfile
.
read
(
wavfile_path
,
dtype
=
'int16'
)
x_len
=
len
(
samples
)
# chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size
=
80
*
16
#80ms, sample_rate = 16kHz
if
x_len
%
chunk_size
!=
0
:
chunk_size
=
85
*
16
#85ms, sample_rate = 16kHz
if
x_len
%
chunk_size
!=
0
:
padding_len_x
=
chunk_size
-
x_len
%
chunk_size
else
:
padding_len_x
=
0
...
...
@@ -48,7 +47,6 @@ class ASRAudioHandler:
assert
(
x_len
+
padding_len_x
)
%
chunk_size
==
0
num_chunk
=
(
x_len
+
padding_len_x
)
/
chunk_size
num_chunk
=
int
(
num_chunk
)
for
i
in
range
(
0
,
num_chunk
):
start
=
i
*
chunk_size
end
=
start
+
chunk_size
...
...
@@ -57,7 +55,11 @@ class ASRAudioHandler:
async
def
run
(
self
,
wavfile_path
:
str
):
logging
.
info
(
"send a message to the server"
)
# self.read_wave()
# send websocket handshake protocol
async
with
websockets
.
connect
(
self
.
url
)
as
ws
:
# server has already received handshake protocol
# client start to send the command
audio_info
=
json
.
dumps
(
{
"name"
:
"test.wav"
,
...
...
@@ -78,7 +80,6 @@ class ASRAudioHandler:
msg
=
json
.
loads
(
msg
)
logging
.
info
(
"receive msg={}"
.
format
(
msg
))
result
=
msg
# finished
audio_info
=
json
.
dumps
(
{
...
...
@@ -91,10 +92,12 @@ class ASRAudioHandler:
separators
=
(
','
,
': '
))
await
ws
.
send
(
audio_info
)
msg
=
await
ws
.
recv
()
# decode the bytes to str
msg
=
json
.
loads
(
msg
)
logging
.
info
(
"receive msg={}"
.
format
(
msg
))
return
result
logging
.
info
(
"
final
receive msg={}"
.
format
(
msg
))
result
=
msg
return
result
def
main
(
args
):
...
...
paddlespeech/server/utils/buffer.py
浏览文件 @
cf9a590f
...
...
@@ -63,12 +63,12 @@ class ChunkBuffer(object):
the sample rate.
Yields Frames of the requested duration.
"""
audio
=
self
.
remained_audio
+
audio
self
.
remained_audio
=
b
''
offset
=
0
timestamp
=
0.0
while
offset
+
self
.
window_bytes
<=
len
(
audio
):
yield
Frame
(
audio
[
offset
:
offset
+
self
.
window_bytes
],
timestamp
,
self
.
window_sec
)
...
...
paddlespeech/server/ws/asr_socket.py
浏览文件 @
cf9a590f
...
...
@@ -13,12 +13,12 @@
# limitations under the License.
import
json
import
numpy
as
np
from
fastapi
import
APIRouter
from
fastapi
import
WebSocket
from
fastapi
import
WebSocketDisconnect
from
starlette.websockets
import
WebSocketState
as
WebSocketState
from
paddlespeech.server.engine.asr.online.asr_engine
import
PaddleASRConnectionHanddler
from
paddlespeech.server.engine.engine_pool
import
get_engine_pool
from
paddlespeech.server.utils.buffer
import
ChunkBuffer
from
paddlespeech.server.utils.vad
import
VADAudio
...
...
@@ -28,26 +28,29 @@ router = APIRouter()
@
router
.
websocket
(
'/ws/asr'
)
async
def
websocket_endpoint
(
websocket
:
WebSocket
):
await
websocket
.
accept
()
engine_pool
=
get_engine_pool
()
asr_engine
=
engine_pool
[
'asr'
]
connection_handler
=
None
# init buffer
# each websocket connection has its own chunk buffer
chunk_buffer_conf
=
asr_engine
.
config
.
chunk_buffer_conf
chunk_buffer
=
ChunkBuffer
(
window_n
=
7
,
shift_n
=
4
,
window_ms
=
20
,
shift_ms
=
10
,
sample_rate
=
chunk_buffer_conf
[
'sample_rate'
],
sample_width
=
chunk_buffer_conf
[
'sample_width'
])
window_n
=
chunk_buffer_conf
.
window_n
,
shift_n
=
chunk_buffer_conf
.
shift_n
,
window_ms
=
chunk_buffer_conf
.
window_ms
,
shift_ms
=
chunk_buffer_conf
.
shift_ms
,
sample_rate
=
chunk_buffer_conf
.
sample_rate
,
sample_width
=
chunk_buffer_conf
.
sample_width
)
# init vad
vad_conf
=
asr_engine
.
config
.
vad_conf
vad
=
VADAudio
(
aggressiveness
=
vad_conf
[
'aggressiveness'
],
rate
=
vad_conf
[
'sample_rate'
],
frame_duration_ms
=
vad_conf
[
'frame_duration_ms'
])
vad_conf
=
asr_engine
.
config
.
get
(
'vad_conf'
,
None
)
if
vad_conf
:
vad
=
VADAudio
(
aggressiveness
=
vad_conf
[
'aggressiveness'
],
rate
=
vad_conf
[
'sample_rate'
],
frame_duration_ms
=
vad_conf
[
'frame_duration_ms'
])
try
:
while
True
:
...
...
@@ -64,13 +67,21 @@ async def websocket_endpoint(websocket: WebSocket):
if
message
[
'signal'
]
==
'start'
:
resp
=
{
"status"
:
"ok"
,
"signal"
:
"server_ready"
}
# do something at beginning here
# create the instance to process the audio
connection_handler
=
PaddleASRConnectionHanddler
(
asr_engine
)
await
websocket
.
send_json
(
resp
)
elif
message
[
'signal'
]
==
'end'
:
engine_pool
=
get_engine_pool
()
asr_engine
=
engine_pool
[
'asr'
]
# reset single engine for a new connection
asr_engine
.
reset
()
resp
=
{
"status"
:
"ok"
,
"signal"
:
"finished"
}
connection_handler
.
decode
(
is_finished
=
True
)
connection_handler
.
rescoring
()
asr_results
=
connection_handler
.
get_result
()
connection_handler
.
reset
()
resp
=
{
"status"
:
"ok"
,
"signal"
:
"finished"
,
'asr_results'
:
asr_results
}
await
websocket
.
send_json
(
resp
)
break
else
:
...
...
@@ -79,21 +90,11 @@ async def websocket_endpoint(websocket: WebSocket):
elif
"bytes"
in
message
:
message
=
message
[
"bytes"
]
engine_pool
=
get_engine_pool
()
asr_engine
=
engine_pool
[
'asr'
]
asr_results
=
""
frames
=
chunk_buffer
.
frame_generator
(
message
)
for
frame
in
frames
:
samples
=
np
.
frombuffer
(
frame
.
bytes
,
dtype
=
np
.
int16
)
sample_rate
=
asr_engine
.
config
.
sample_rate
x_chunk
,
x_chunk_lens
=
asr_engine
.
preprocess
(
samples
,
sample_rate
)
asr_engine
.
run
(
x_chunk
,
x_chunk_lens
)
asr_results
=
asr_engine
.
postprocess
()
connection_handler
.
extract_feat
(
message
)
connection_handler
.
decode
(
is_finished
=
False
)
asr_results
=
connection_handler
.
get_result
()
asr_results
=
asr_engine
.
postprocess
()
resp
=
{
'asr_results'
:
asr_results
}
await
websocket
.
send_json
(
resp
)
except
WebSocketDisconnect
:
pass
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录