PaddlePaddle / DeepSpeech

Commit d847fe29
Author: WilliamZhang06
Date: March 30, 2022
Parent: 6f0b3a15

    added online asr engine , test=doc

Showing 14 changed files with 996 additions and 4 deletions
paddlespeech/s2t/frontend/audio.py                          +12    -0
paddlespeech/s2t/frontend/speech.py                         +16    -0
paddlespeech/server/bin/main.py                              +8    -2
paddlespeech/server/conf/application.yaml                   +25    -2
paddlespeech/server/engine/asr/online/__init__.py           +13    -0
paddlespeech/server/engine/asr/online/asr_engine.py        +355    -0
paddlespeech/server/engine/engine_factory.py                 +3    -0
paddlespeech/server/tests/asr/online/microphone_client.py  +154    -0
paddlespeech/server/tests/asr/online/websocket_client.py   +115    -0
paddlespeech/server/utils/buffer.py                         +59    -0
paddlespeech/server/utils/vad.py                            +79    -0
paddlespeech/server/ws/__init__.py                          +13    -0
paddlespeech/server/ws/api.py                               +38    -0
paddlespeech/server/ws/asr_socket.py                       +106    -0
paddlespeech/s2t/frontend/audio.py

@@ -208,6 +208,18 @@ class AudioSegment():
             io.BytesIO(bytes), dtype='float32')
         return cls(samples, sample_rate)
 
+    @classmethod
+    def from_pcm(cls, samples, sample_rate):
+        """Create audio segment from raw PCM audio samples.
+        :param samples: Audio samples [num_samples x num_channels].
+        :type samples: numpy.ndarray
+        :param sample_rate: Audio sample rate.
+        :type sample_rate: int
+        :return: Audio segment instance.
+        :rtype: AudioSegment
+        """
+        return cls(samples, sample_rate)
+
     @classmethod
     def concatenate(cls, *segments):
         """Concatenate an arbitrary number of audio segments together.
paddlespeech/s2t/frontend/speech.py

@@ -107,6 +107,22 @@ class SpeechSegment(AudioSegment):
         return cls(audio.samples, audio.sample_rate, transcript, tokens,
                    token_ids)
 
+    @classmethod
+    def from_pcm(cls, samples, sample_rate, transcript, tokens=None,
+                 token_ids=None):
+        """Create a speech segment from PCM samples in online (streaming) mode.
+        Args:
+            samples (numpy.ndarray): Audio samples [num_samples x num_channels].
+            sample_rate (int): Audio sample rate.
+            transcript (str): Transcript text for the speech.
+            tokens (List[str], optional): text tokens. Defaults to None.
+            token_ids (List[int], optional): text token ids. Defaults to None.
+        Returns:
+            SpeechSegment: Speech segment instance.
+        """
+        audio = AudioSegment.from_pcm(samples, sample_rate)
+        return cls(audio.samples, audio.sample_rate, transcript, tokens,
+                   token_ids)
+
     @classmethod
     def concatenate(cls, *segments):
         """Concatenate an arbitrary number of speech segments together, both
paddlespeech/server/bin/main.py

@@ -17,7 +17,8 @@ import uvicorn
 from fastapi import FastAPI
 
 from paddlespeech.server.engine.engine_pool import init_engine_pool
-from paddlespeech.server.restful.api import setup_router
+from paddlespeech.server.restful.api import setup_router as setup_http_router
+from paddlespeech.server.ws.api import setup_router as setup_ws_router
 from paddlespeech.server.utils.config import get_config
 
 app = FastAPI(
@@ -35,7 +36,12 @@ def init(config):
     """
     # init api
     api_list = list(engine.split("_")[0] for engine in config.engine_list)
-    api_router = setup_router(api_list)
+    if config.protocol == "websocket":
+        api_router = setup_ws_router(api_list)
+    elif config.protocol == "http":
+        api_router = setup_http_router(api_list)
+    else:
+        raise Exception("unsupported protocol")
     app.include_router(api_router)
 
     if not init_engine_pool(config):
paddlespeech/server/conf/application.yaml

@@ -3,13 +3,18 @@
 #################################################################################
 #                             SERVER SETTING                                    #
 #################################################################################
-host: 127.0.0.1
+host: 0.0.0.0
 port: 8090
 
 # The task format in the engine_list is: <speech task>_<engine type>
 # task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference']
 # protocol: 'http'
-engine_list: ['asr_python', 'tts_python', 'cls_python']
+
+# websocket, http (only choose one). websocket only supports the online engine type.
+protocol: 'websocket'
+engine_list: ['asr_online']
 
 #################################################################################
@@ -48,6 +53,24 @@ asr_inference:
         summary: True  # False -> do not show predictor config
 
+################### speech task: asr; engine_type: online #######################
+asr_online:
+    model_type: 'deepspeech2online_aishell'
+    am_model:    # the pdmodel file of am static model [optional]
+    am_params:   # the pdiparams file of am static model [optional]
+    lang: 'zh'
+    sample_rate: 16000
+    cfg_path:
+    decode_method:
+    force_yes: True
+
+    am_predictor_conf:
+        device:           # set 'gpu:id' or 'cpu'
+        switch_ir_optim: True
+        glog_info: False  # True -> print glog
+        summary: True     # False -> do not show predictor config
+
 ################################### TTS #########################################
 ################### speech task: tts; engine_type: python #######################
 tts_python:
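These two settings drive the dispatch in init() shown above: protocol selects which router factory is used, and each engine_list entry is split at the underscore into a speech task and an engine type. A minimal self-contained sketch of that selection (values inlined here instead of being read through get_config):

protocol = 'websocket'        # from application.yaml
engine_list = ['asr_online']  # from application.yaml

# Each entry is "<speech task>_<engine type>"; the task half picks the API.
api_list = [engine.split("_")[0] for engine in engine_list]
assert api_list == ['asr']

if protocol == "websocket":
    print("mount /ws/* routes via setup_ws_router(api_list)")
elif protocol == "http":
    print("mount RESTful routes via setup_http_router(api_list)")
else:
    raise Exception("unsupported protocol")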
paddlespeech/server/engine/asr/online/__init__.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/engine/asr/online/asr_engine.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import time
from typing import Optional
import pickle
import numpy as np
from numpy import float32
import soundfile

import paddle
from yacs.config import CfgNode

from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model

__all__ = ['ASREngine']

pretrained_models = {
    "deepspeech2online_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
        'md5':
        'd5e076217cf60486519f72c217d21b9b',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2_online/checkpoints/avg_1',
        'model':
        'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
        'params':
        'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
}


class ASRServerExecutor(ASRExecutor):
    def __init__(self):
        super().__init__()

    def _init_from_path(self,
                        model_type: str='wenetspeech',
                        am_model: Optional[os.PathLike]=None,
                        am_params: Optional[os.PathLike]=None,
                        lang: str='zh',
                        sample_rate: int=16000,
                        cfg_path: Optional[os.PathLike]=None,
                        decode_method: str='attention_rescoring',
                        am_predictor_conf: dict=None):
        """
        Init model and other resources from a specific path.
        """
        if cfg_path is None or am_model is None or am_params is None:
            sample_rate_str = '16k' if sample_rate == 16000 else '8k'
            tag = model_type + '-' + lang + '-' + sample_rate_str
            res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
            self.res_path = res_path
            self.cfg_path = os.path.join(res_path,
                                         pretrained_models[tag]['cfg_path'])
            self.am_model = os.path.join(res_path,
                                         pretrained_models[tag]['model'])
            self.am_params = os.path.join(res_path,
                                          pretrained_models[tag]['params'])
            logger.info(res_path)
            logger.info(self.cfg_path)
            logger.info(self.am_model)
            logger.info(self.am_params)
        else:
            self.cfg_path = os.path.abspath(cfg_path)
            self.am_model = os.path.abspath(am_model)
            self.am_params = os.path.abspath(am_params)
            self.res_path = os.path.dirname(
                os.path.dirname(os.path.abspath(self.cfg_path)))

        # Init body.
        self.config = CfgNode(new_allowed=True)
        self.config.merge_from_file(self.cfg_path)

        with UpdateConfig(self.config):
            if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
                from paddlespeech.s2t.io.collator import SpeechCollator
                self.vocab = self.config.vocab_filepath
                self.config.decode.lang_model_path = os.path.join(
                    MODEL_HOME, 'language_model',
                    self.config.decode.lang_model_path)
                self.collate_fn_test = SpeechCollator.from_config(self.config)
                self.text_feature = TextFeaturizer(
                    unit_type=self.config.unit_type, vocab=self.vocab)

                lm_url = pretrained_models[tag]['lm_url']
                lm_md5 = pretrained_models[tag]['lm_md5']
                self.download_lm(
                    lm_url,
                    os.path.dirname(self.config.decode.lang_model_path), lm_md5)
            elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
                raise Exception("wrong type")
            else:
                raise Exception("wrong type")

        # AM predictor
        self.am_predictor_conf = am_predictor_conf
        self.am_predictor = init_predictor(
            model_file=self.am_model,
            params_file=self.am_params,
            predictor_conf=self.am_predictor_conf)

        # decoder
        self.decoder = CTCDecoder(
            odim=self.config.output_dim,  # <blank> is in vocab
            enc_n_units=self.config.rnn_layer_size * 2,
            blank_id=self.config.blank_id,
            dropout_rate=0.0,
            reduction=True,  # sum
            batch_average=True,  # sum / batch_size
            grad_norm_type=self.config.get('ctc_grad_norm_type', None))

        # init decoder
        cfg = self.config.decode
        decode_batch_size = 1  # for online
        self.decoder.init_decoder(
            decode_batch_size, self.text_feature.vocab_list,
            cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
            cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
            cfg.num_proc_bsearch)

        # init state box
        self.chunk_state_h_box = np.zeros(
            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
            dtype=float32)
        self.chunk_state_c_box = np.zeros(
            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
            dtype=float32)

    def reset_decoder_and_chunk(self):
        """reset decoder and chunk state for a new audio
        """
        self.decoder.reset_decoder(batch_size=1)
        # init state box, for new audio request
        self.chunk_state_h_box = np.zeros(
            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
            dtype=float32)
        self.chunk_state_c_box = np.zeros(
            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
            dtype=float32)

    def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
        """decode one chunk
        Args:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
            model_type (str): online model type
        Returns:
            str: best transcription result of the audio decoded so far
        """
        if "deepspeech2online" in model_type:
            input_names = self.am_predictor.get_input_names()
            audio_handle = self.am_predictor.get_input_handle(input_names[0])
            audio_len_handle = self.am_predictor.get_input_handle(
                input_names[1])
            h_box_handle = self.am_predictor.get_input_handle(input_names[2])
            c_box_handle = self.am_predictor.get_input_handle(input_names[3])

            audio_handle.reshape(x_chunk.shape)
            audio_handle.copy_from_cpu(x_chunk)

            audio_len_handle.reshape(x_chunk_lens.shape)
            audio_len_handle.copy_from_cpu(x_chunk_lens)

            h_box_handle.reshape(self.chunk_state_h_box.shape)
            h_box_handle.copy_from_cpu(self.chunk_state_h_box)

            c_box_handle.reshape(self.chunk_state_c_box.shape)
            c_box_handle.copy_from_cpu(self.chunk_state_c_box)

            output_names = self.am_predictor.get_output_names()
            output_handle = self.am_predictor.get_output_handle(
                output_names[0])
            output_lens_handle = self.am_predictor.get_output_handle(
                output_names[1])
            output_state_h_handle = self.am_predictor.get_output_handle(
                output_names[2])
            output_state_c_handle = self.am_predictor.get_output_handle(
                output_names[3])

            self.am_predictor.run()

            output_chunk_probs = output_handle.copy_to_cpu()
            output_chunk_lens = output_lens_handle.copy_to_cpu()
            self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
            self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()

            self.decoder.next(output_chunk_probs, output_chunk_lens)
            trans_best, trans_beam = self.decoder.decode()

            return trans_best[0]
        elif "conformer" in model_type or "transformer" in model_type:
            raise Exception("invalid model name")
        else:
            raise Exception("invalid model name")

    def _pcm16to32(self, audio):
        """pcm int16 to float32
        Args:
            audio (numpy.array): numpy.int16
        Returns:
            audio (numpy.array): numpy.float32
        """
        if audio.dtype == np.int16:
            audio = audio.astype("float32")
            bits = np.iinfo(np.int16).bits
            audio = audio / (2**(bits - 1))
        return audio

    def extract_feat(self, samples, sample_rate):
        """extract feat
        Args:
            samples (numpy.array): numpy.float32
            sample_rate (int): sample rate
        Returns:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
        """
        # pcm16 -> pcm 32
        samples = self._pcm16to32(samples)

        # read audio
        speech_segment = SpeechSegment.from_pcm(
            samples, sample_rate, transcript=" ")
        # audio augment
        self.collate_fn_test.augmentation.transform_audio(speech_segment)

        # extract speech feature
        spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
            speech_segment, self.collate_fn_test.keep_transcription_text)
        # CMVN spectrum
        if self.collate_fn_test._normalizer:
            spectrum = self.collate_fn_test._normalizer.apply(spectrum)

        # spectrum augment
        audio = self.collate_fn_test.augmentation.transform_feature(spectrum)

        audio_len = audio.shape[0]
        audio = paddle.to_tensor(audio, dtype='float32')
        # audio_len = paddle.to_tensor(audio_len)
        audio = paddle.unsqueeze(audio, axis=0)

        x_chunk = audio.numpy()
        x_chunk_lens = np.array([audio_len])

        return x_chunk, x_chunk_lens


class ASREngine(BaseEngine):
    """ASR server engine
    Args:
        metaclass: Defaults to Singleton.
    """

    def __init__(self):
        super(ASREngine, self).__init__()

    def init(self, config: dict) -> bool:
        """init engine resource
        Args:
            config (dict): engine config
        Returns:
            bool: init failed or success
        """
        self.input = None
        self.output = ""
        self.executor = ASRServerExecutor()
        self.config = config

        self.executor._init_from_path(
            model_type=self.config.model_type,
            am_model=self.config.am_model,
            am_params=self.config.am_params,
            lang=self.config.lang,
            sample_rate=self.config.sample_rate,
            cfg_path=self.config.cfg_path,
            decode_method=self.config.decode_method,
            am_predictor_conf=self.config.am_predictor_conf)

        logger.info("Initialize ASR server engine successfully.")
        return True

    def preprocess(self, samples, sample_rate):
        """preprocess
        Args:
            samples (numpy.array): numpy.float32
            sample_rate (int): sample rate
        Returns:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
        """
        x_chunk, x_chunk_lens = self.executor.extract_feat(samples,
                                                           sample_rate)
        return x_chunk, x_chunk_lens

    def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
        """run online engine
        Args:
            x_chunk (numpy.array): shape[B, T, D]
            x_chunk_lens (numpy.array): shape[B]
            decoder_chunk_size (int)
        """
        self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens,
                                                     self.config.model_type)

    def postprocess(self):
        """postprocess
        """
        return self.output

    def reset(self):
        """reset engine decoder and inference state
        """
        self.executor.reset_decoder_and_chunk()
        self.output = ""
paddlespeech/server/engine/engine_factory.py

@@ -25,6 +25,9 @@ class EngineFactory(object):
         elif engine_name == 'asr' and engine_type == 'python':
             from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
             return ASREngine()
+        elif engine_name == 'asr' and engine_type == 'online':
+            from paddlespeech.server.engine.asr.online.asr_engine import ASREngine
+            return ASREngine()
         elif engine_name == 'tts' and engine_type == 'inference':
             from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine
             return TTSEngine()
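A hedged usage sketch for context; the enclosing factory method is not visible in this hunk, so the name get_engine below is an assumption:

# Assumed entry point; the method name is not shown in this hunk.
engine = EngineFactory.get_engine(engine_name='asr', engine_type='online')
# The import sits inside the branch, so the online backend is only loaded
# when an 'asr_online' entry actually appears in engine_list.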
paddlespeech/server/tests/asr/online/microphone_client.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
record wave from the mic
"""
import
threading
import
pyaudio
import
wave
import
logging
import
asyncio
import
websockets
import
json
from
signal
import
SIGINT
,
SIGTERM
class
ASRAudioHandler
(
threading
.
Thread
):
def
__init__
(
self
,
url
=
"127.0.0.1"
,
port
=
8090
):
threading
.
Thread
.
__init__
(
self
)
self
.
url
=
url
self
.
port
=
port
self
.
url
=
"ws://"
+
self
.
url
+
":"
+
str
(
self
.
port
)
+
"/ws/asr"
self
.
fileName
=
"./output.wav"
self
.
chunk
=
5120
self
.
format
=
pyaudio
.
paInt16
self
.
channels
=
1
self
.
rate
=
16000
self
.
_running
=
True
self
.
_frames
=
[]
self
.
data_backup
=
[]
def
startrecord
(
self
):
"""
start a new thread to record wave
"""
threading
.
_start_new_thread
(
self
.
recording
,
())
def
recording
(
self
):
"""
recording wave
"""
self
.
_running
=
True
self
.
_frames
=
[]
p
=
pyaudio
.
PyAudio
()
stream
=
p
.
open
(
format
=
self
.
format
,
channels
=
self
.
channels
,
rate
=
self
.
rate
,
input
=
True
,
frames_per_buffer
=
self
.
chunk
)
while
(
self
.
_running
):
data
=
stream
.
read
(
self
.
chunk
)
self
.
_frames
.
append
(
data
)
self
.
data_backup
.
append
(
data
)
stream
.
stop_stream
()
stream
.
close
()
p
.
terminate
()
def
save
(
self
):
"""
save wave data
"""
p
=
pyaudio
.
PyAudio
()
wf
=
wave
.
open
(
self
.
fileName
,
'wb'
)
wf
.
setnchannels
(
self
.
channels
)
wf
.
setsampwidth
(
p
.
get_sample_size
(
self
.
format
))
wf
.
setframerate
(
self
.
rate
)
wf
.
writeframes
(
b
''
.
join
(
self
.
data_backup
))
wf
.
close
()
p
.
terminate
()
def
stoprecord
(
self
):
"""
stop recording
"""
self
.
_running
=
False
async
def
run
(
self
):
aa
=
input
(
"是否开始录音? (y/n)"
)
if
aa
.
strip
()
==
"y"
:
self
.
startrecord
()
logging
.
info
(
"*"
*
10
+
"开始录音,请输入语音"
)
async
with
websockets
.
connect
(
self
.
url
)
as
ws
:
# 发送开始指令
audio_info
=
json
.
dumps
({
"name"
:
"test.wav"
,
"signal"
:
"start"
,
"nbest"
:
5
},
sort_keys
=
True
,
indent
=
4
,
separators
=
(
','
,
': '
))
await
ws
.
send
(
audio_info
)
msg
=
await
ws
.
recv
()
logging
.
info
(
"receive msg={}"
.
format
(
msg
))
# send bytes data
logging
.
info
(
"结束录音请: Ctrl + c。继续请按回车。"
)
try
:
while
True
:
while
len
(
self
.
_frames
)
>
0
:
await
ws
.
send
(
self
.
_frames
.
pop
(
0
))
msg
=
await
ws
.
recv
()
logging
.
info
(
"receive msg={}"
.
format
(
msg
))
except
asyncio
.
CancelledError
:
# quit
# send finished
audio_info
=
json
.
dumps
({
"name"
:
"test.wav"
,
"signal"
:
"end"
,
"nbest"
:
5
},
sort_keys
=
True
,
indent
=
4
,
separators
=
(
','
,
': '
))
await
ws
.
send
(
audio_info
)
msg
=
await
ws
.
recv
()
logging
.
info
(
"receive msg={}"
.
format
(
msg
))
self
.
stoprecord
()
logging
.
info
(
"*"
*
10
+
"录音结束"
)
self
.
save
()
elif
aa
.
strip
()
==
"n"
:
exit
()
else
:
print
(
"无效输入!"
)
exit
()
if
__name__
==
"__main__"
:
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
info
(
"asr websocket client start"
)
handler
=
ASRAudioHandler
(
"127.0.0.1"
,
8090
)
loop
=
asyncio
.
get_event_loop
()
main_task
=
asyncio
.
ensure_future
(
handler
.
run
())
for
signal
in
[
SIGINT
,
SIGTERM
]:
loop
.
add_signal_handler
(
signal
,
main_task
.
cancel
)
try
:
loop
.
run_until_complete
(
main_task
)
finally
:
loop
.
close
()
logging
.
info
(
"asr websocket client finished"
)
paddlespeech/server/tests/asr/online/websocket_client.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import argparse
import logging
import time
import os
import json
import wave
import numpy as np
import asyncio
import websockets
import soundfile


class ASRAudioHandler:
    def __init__(self, url="127.0.0.1", port=8090):
        self.url = url
        self.port = port
        self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"

    def read_wave(self, wavfile_path: str):
        samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
        x_len = len(samples)
        chunk_stride = 40 * 16  # 40ms, sample_rate = 16kHz
        chunk_size = 80 * 16  # 80ms, sample_rate = 16kHz

        if (x_len - chunk_size) % chunk_stride != 0:
            padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
        else:
            padding_len_x = 0

        padding = np.zeros((padding_len_x), dtype=samples.dtype)
        padded_x = np.concatenate([samples, padding], axis=0)

        num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
        num_chunk = int(num_chunk)

        for i in range(0, num_chunk):
            start = i * chunk_stride
            end = start + chunk_size
            x_chunk = padded_x[start:end]
            yield x_chunk

    async def run(self, wavfile_path: str):
        logging.info("send a message to the server")

        # read the audio
        # self.read_wave()
        # send the websocket handshake request
        async with websockets.connect(self.url) as ws:
            # the server has already received the handshake header
            # send the start signal
            audio_info = json.dumps(
                {
                    "name": "test.wav",
                    "signal": "start",
                    "nbest": 5
                },
                sort_keys=True,
                indent=4,
                separators=(',', ': '))
            await ws.send(audio_info)
            msg = await ws.recv()
            logging.info("receive msg={}".format(msg))

            # send chunk audio data to engine
            for chunk_data in self.read_wave(wavfile_path):
                await ws.send(chunk_data.tobytes())
                msg = await ws.recv()
                logging.info("receive msg={}".format(msg))

            # finished: send the end signal
            audio_info = json.dumps(
                {
                    "name": "test.wav",
                    "signal": "end",
                    "nbest": 5
                },
                sort_keys=True,
                indent=4,
                separators=(',', ': '))
            await ws.send(audio_info)
            msg = await ws.recv()
            logging.info("receive msg={}".format(msg))


def main(args):
    logging.basicConfig(level=logging.INFO)
    logging.info("asr websocket client start")
    handler = ASRAudioHandler("127.0.0.1", 8090)
    loop = asyncio.get_event_loop()
    loop.run_until_complete(handler.run(args.wavfile))
    logging.info("asr websocket client finished")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wavfile",
        action="store",
        help="wav file path",
        default="./16_audio.wav")
    args = parser.parse_args()
    main(args)
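The constants in read_wave() encode an 80 ms window with a 40 ms hop at 16 kHz (1280 and 640 samples). A quick check of the padding arithmetic with an arbitrary example length:

chunk_size = 80 * 16    # 1280 samples (80 ms at 16 kHz)
chunk_stride = 40 * 16  # 640 samples (40 ms hop, 50% overlap)

x_len = 52000                                        # e.g. a 3.25 s recording
rem = (x_len - chunk_size) % chunk_stride            # 160 samples left over
padding_len_x = (chunk_stride - rem) % chunk_stride  # pad with 480 zeros
num_chunk = (x_len + padding_len_x - chunk_size) // chunk_stride + 1
assert num_chunk == 81                               # 81 overlapping chunks sent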
paddlespeech/server/utils/buffer.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Frame(object):
    """Represents a "frame" of audio data."""

    def __init__(self, bytes, timestamp, duration):
        self.bytes = bytes
        self.timestamp = timestamp
        self.duration = duration


class ChunkBuffer(object):
    def __init__(self,
                 frame_duration_ms=80,
                 shift_ms=40,
                 sample_rate=16000,
                 sample_width=2):
        self.sample_rate = sample_rate
        self.frame_duration_ms = frame_duration_ms
        self.shift_ms = shift_ms
        self.remained_audio = b''
        self.sample_width = sample_width  # int16 = 2; float32 = 4

    def frame_generator(self, audio):
        """Generates audio frames from PCM audio data.
        Takes the desired frame duration in milliseconds, the PCM data, and
        the sample rate.
        Yields Frames of the requested duration.
        """
        audio = self.remained_audio + audio
        self.remained_audio = b''

        n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) *
                self.sample_width)
        shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) *
                      self.sample_width)
        offset = 0
        timestamp = 0.0
        duration = (float(n) / self.sample_rate) / self.sample_width
        shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width
        while offset + n <= len(audio):
            yield Frame(audio[offset:offset + n], timestamp, duration)
            timestamp += shift_duration
            offset += shift_n

        self.remained_audio += audio[offset:]
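With the defaults, each Frame carries 16000 × 0.08 × 2 = 2560 bytes and the window shifts by 1280 bytes. A minimal sketch of one call:

buf = ChunkBuffer()         # 80 ms frames, 40 ms shift, 16 kHz int16
pcm = b'\x00\x00' * 16000   # one second of silent int16 audio (32000 bytes)

frames = list(buf.frame_generator(pcm))
assert len(frames) == 24                # overlapping 2560-byte frames
assert len(buf.remained_audio) == 1280  # tail carried into the next call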
paddlespeech/server/utils/vad.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging

import webrtcvad


class VADAudio():
    def __init__(self,
                 aggressiveness,
                 rate,
                 frame_duration_ms,
                 sample_width=2,
                 padding_ms=200,
                 padding_ratio=0.9):
        """Initializes VAD with given aggressiveness and sets up internal queues"""
        self.vad = webrtcvad.Vad(aggressiveness)
        self.rate = rate
        self.sample_width = sample_width
        self.frame_duration_ms = frame_duration_ms
        self._frame_length = int(rate * (frame_duration_ms / 1000.0) *
                                 self.sample_width)

        self._buffer_queue = collections.deque()
        self.ring_buffer = collections.deque(
            maxlen=padding_ms // frame_duration_ms)
        self._ratio = padding_ratio
        self.triggered = False

    def add_audio(self, audio):
        """Adds new audio to the internal queue"""
        for x in audio:
            self._buffer_queue.append(x)

    def frame_generator(self):
        """Generator that yields audio frames of frame_duration_ms"""
        while len(self._buffer_queue) > self._frame_length:
            frame = bytearray()
            for _ in range(self._frame_length):
                frame.append(self._buffer_queue.popleft())
            yield bytes(frame)

    def vad_collector(self):
        """Generator that yields series of consecutive audio frames comprising
        each utterance, separated by yielding a single None.
        Determines voice activity by the ratio of voiced frames in padding_ms.
        Uses a buffer to include padding_ms prior to being triggered.
        Example: (frame, ..., frame, None, frame, ..., frame, None, ...)
                  |---utterance---|       |---utterance---|
        """
        for frame in self.frame_generator():
            is_speech = self.vad.is_speech(frame, self.rate)
            if not self.triggered:
                self.ring_buffer.append((frame, is_speech))
                num_voiced = len(
                    [f for f, speech in self.ring_buffer if speech])
                if num_voiced > self._ratio * self.ring_buffer.maxlen:
                    self.triggered = True
                    for f, s in self.ring_buffer:
                        yield f
                    self.ring_buffer.clear()
            else:
                yield frame
                self.ring_buffer.append((frame, is_speech))
                num_unvoiced = len(
                    [f for f, speech in self.ring_buffer if not speech])
                if num_unvoiced > self._ratio * self.ring_buffer.maxlen:
                    self.triggered = False
                    yield None
                    self.ring_buffer.clear()
paddlespeech/server/ws/__init__.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/ws/api.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

from fastapi import APIRouter

from paddlespeech.server.ws.asr_socket import router as asr_router

_router = APIRouter()


def setup_router(api_list: List):
    """setup router for fastapi
    Args:
        api_list (List): [asr, tts]
    Returns:
        APIRouter
    """
    for api_name in api_list:
        if api_name == 'asr':
            _router.include_router(asr_router)
        elif api_name == 'tts':
            pass
        else:
            pass

    return _router
paddlespeech/server/ws/asr_socket.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import traceback
from typing import Union
import random
import numpy as np
import json

from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState

from paddlespeech.server.engine.asr.online.asr_engine import ASREngine
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio

router = APIRouter()


@router.websocket('/ws/asr')
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()

    # init buffer
    chunk_buffer = ChunkBuffer(sample_width=2)
    # init vad
    vad = VADAudio(2, 16000, 20)

    try:
        while True:
            # careful here, changed the source code from starlette.websockets
            assert websocket.application_state == WebSocketState.CONNECTED
            message = await websocket.receive()
            websocket._raise_on_disconnect(message)

            if "text" in message:
                message = json.loads(message["text"])
                if 'signal' not in message:
                    resp = {"status": "ok", "message": "no valid json data"}
                    await websocket.send_json(resp)

                if message['signal'] == 'start':
                    resp = {"status": "ok", "signal": "server_ready"}
                    # do something at the beginning here
                    await websocket.send_json(resp)
                elif message['signal'] == 'end':
                    engine_pool = get_engine_pool()
                    asr_engine = engine_pool['asr']
                    # reset single engine for a new connection
                    asr_engine.reset()
                    resp = {"status": "ok", "signal": "finished"}
                    await websocket.send_json(resp)
                    break
                else:
                    resp = {"status": "ok", "message": "no valid json data"}
                    await websocket.send_json(resp)

            elif "bytes" in message:
                message = message["bytes"]

                # vad for input bytes audio
                vad.add_audio(message)
                message = b''.join(f for f in vad.vad_collector()
                                   if f is not None)

                engine_pool = get_engine_pool()
                asr_engine = engine_pool['asr']
                asr_results = ""
                frames = chunk_buffer.frame_generator(message)
                for frame in frames:
                    samples = np.frombuffer(frame.bytes, dtype=np.int16)
                    sample_rate = asr_engine.config.sample_rate
                    x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
                                                                  sample_rate)
                    asr_engine.run(x_chunk, x_chunk_lens)
                    asr_results = asr_engine.postprocess()

                asr_results = asr_engine.postprocess()
                resp = {'asr_results': asr_results}

                await websocket.send_json(resp)
    except WebSocketDisconnect:
        pass
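Taken together with the test clients above, one session over /ws/asr looks like the following exchange (the transcription text is illustrative):

client -> {"name": "test.wav", "signal": "start", "nbest": 5}
server -> {"status": "ok", "signal": "server_ready"}
client -> <binary int16 PCM chunk>                    (repeated per chunk)
server -> {"asr_results": "<partial transcription>"}
client -> {"name": "test.wav", "signal": "end", "nbest": 5}
server -> {"status": "ok", "signal": "finished"}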