PaddlePaddle / DeepSpeech

Commit d21ccd02
Authored Apr 15, 2022 by xiongxinlei
add conformer online server, test=doc
Parent: af484fc9
Showing 8 changed files with 272 additions and 113 deletions (+272 −113)
paddlespeech/cli/asr/infer.py                              +40  −16
paddlespeech/s2t/models/u2/u2.py                           +6   −2
paddlespeech/s2t/modules/ctc.py                            +2   −1
paddlespeech/s2t/modules/encoder.py                        +2   −0
paddlespeech/server/conf/ws_application.yaml               +44  −10
paddlespeech/server/engine/asr/online/asr_engine.py        +160 −71
paddlespeech/server/tests/asr/online/websocket_client.py   +1   −1
paddlespeech/server/ws/asr_socket.py                       +17  −12
paddlespeech/cli/asr/infer.py
@@ -91,6 +91,20 @@ pretrained_models = {
         'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
         'lm_md5':
         '29e02312deb2e59b3c8686c7966d4fe3'
     },
+    "conformer2online_aishell-zh-16k": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz',
+        'md5':
+        '4814e52e0fc2fd48899373f95c84b0c9',
+        'cfg_path':
+        'config.yaml',
+        'ckpt_path':
+        'exp/deepspeech2_online/checkpoints/avg_30',
+        'lm_url':
+        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+        'lm_md5':
+        '29e02312deb2e59b3c8686c7966d4fe3'
+    },
     "deepspeech2offline_librispeech-en-16k": {
         'url':
@@ -115,6 +129,8 @@ model_alias = {
     "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
     "conformer":
     "paddlespeech.s2t.models.u2:U2Model",
+    "conformer2online":
+    "paddlespeech.s2t.models.u2:U2Model",
     "transformer":
     "paddlespeech.s2t.models.u2:U2Model",
     "wenetspeech":
@@ -219,6 +235,7 @@ class ASRExecutor(BaseExecutor):
         """
         Init model and other resources from a specific path.
         """
+        logger.info("start to init the model")
         if hasattr(self, 'model'):
             logger.info('Model had been initialized.')
             return
@@ -233,14 +250,15 @@ class ASRExecutor(BaseExecutor):
             self.ckpt_path = os.path.join(
                 res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams")
+            logger.info(res_path)
             logger.info(self.cfg_path)
             logger.info(self.ckpt_path)
         else:
             self.cfg_path = os.path.abspath(cfg_path)
             self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
             self.res_path = os.path.dirname(
                 os.path.dirname(os.path.abspath(self.cfg_path)))
             logger.info(self.cfg_path)
             logger.info(self.ckpt_path)

         #Init body.
         self.config = CfgNode(new_allowed=True)
         self.config.merge_from_file(self.cfg_path)
@@ -269,7 +287,6 @@ class ASRExecutor(BaseExecutor):
                 vocab=self.config.vocab_filepath,
                 spm_model_prefix=self.config.spm_model_prefix)
             self.config.decode.decoding_method = decode_method
         else:
             raise Exception("wrong type")
         model_name = model_type[:model_type.rindex(
@@ -347,12 +364,14 @@ class ASRExecutor(BaseExecutor):
         else:
             raise Exception("wrong type")
+        logger.info("audio feat process success")

     @paddle.no_grad()
     def infer(self, model_type: str):
         """
         Model inference and result stored in self.output.
         """
+        logger.info("start to infer the model to get the output")
         cfg = self.config.decode
         audio = self._inputs["audio"]
         audio_len = self._inputs["audio_len"]
@@ -369,17 +388,22 @@ class ASRExecutor(BaseExecutor):
             self._outputs["result"] = result_transcripts[0]
         elif "conformer" in model_type or "transformer" in model_type:
-            result_transcripts = self.model.decode(
-                audio,
-                audio_len,
-                text_feature=self.text_feature,
-                decoding_method=cfg.decoding_method,
-                beam_size=cfg.beam_size,
-                ctc_weight=cfg.ctc_weight,
-                decoding_chunk_size=cfg.decoding_chunk_size,
-                num_decoding_left_chunks=cfg.num_decoding_left_chunks,
-                simulate_streaming=cfg.simulate_streaming)
-            self._outputs["result"] = result_transcripts[0][0]
+            logger.info(
+                f"we will use the transformer like model : {model_type}")
+            try:
+                result_transcripts = self.model.decode(
+                    audio,
+                    audio_len,
+                    text_feature=self.text_feature,
+                    decoding_method=cfg.decoding_method,
+                    beam_size=cfg.beam_size,
+                    ctc_weight=cfg.ctc_weight,
+                    decoding_chunk_size=cfg.decoding_chunk_size,
+                    num_decoding_left_chunks=cfg.num_decoding_left_chunks,
+                    simulate_streaming=cfg.simulate_streaming)
+                self._outputs["result"] = result_transcripts[0][0]
+            except Exception as e:
+                logger.exception(e)
         else:
             raise Exception("invalid model name")
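With the "conformer2online_aishell-zh-16k" entry and the "conformer2online" alias registered above, the streaming conformer can in principle be exercised through the same CLI executor as the existing models. A minimal sketch, assuming the usual ASRExecutor call signature and that the tag is resolved from model, lang and sample_rate (the file name zh.wav is a placeholder, not part of this diff):

    # Hypothetical smoke test for the new model tag; the keyword names
    # follow the existing ASRExecutor CLI and are assumptions here.
    from paddlespeech.cli.asr.infer import ASRExecutor

    asr = ASRExecutor()
    # "conformer2online_aishell" + "zh" + 16000 should resolve to the
    # "conformer2online_aishell-zh-16k" entry in pretrained_models
    result = asr(
        audio_file="zh.wav",  # placeholder: any 16 kHz mono wav
        model="conformer2online_aishell",
        lang="zh",
        sample_rate=16000)
    print(result)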
paddlespeech/s2t/models/u2/u2.py
@@ -213,12 +213,14 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 num_decoding_left_chunks=num_decoding_left_chunks
             )  # (B, maxlen, encoder_dim)
         else:
+            print("offline decode from the asr")
             encoder_out, encoder_mask = self.encoder(
                 speech,
                 speech_lengths,
                 decoding_chunk_size=decoding_chunk_size,
                 num_decoding_left_chunks=num_decoding_left_chunks
             )  # (B, maxlen, encoder_dim)
+            print("offline decode success")
         return encoder_out, encoder_mask

     def recognize(
@@ -706,13 +708,15 @@ class U2BaseModel(ASRInterface, nn.Layer):
             List[List[int]]: transcripts.
         """
         batch_size = feats.shape[0]
+        print("start to decode the audio feat")
         if decoding_method in ['ctc_prefix_beam_search',
                                'attention_rescoring'] and batch_size > 1:
-            logger.fatal(
-                f'decoding mode {decoding_method} must be running with batch_size == 1')
+            logger.error(
+                f'decoding mode {decoding_method} must be running with batch_size == 1')
+            logger.error(f"current batch_size is {batch_size}")
             sys.exit(1)
+        print(f"use the {decoding_method} to decode the audio feat")
         if decoding_method == 'attention':
             hyps = self.recognize(
                 feats,
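For orientation, decode() dispatches on decoding_method and forwards the chunking knobs to the encoder; in the U2 family, decoding_chunk_size < 0 conventionally means full-utterance (offline) encoding while a positive value fixes the streaming chunk width. A minimal sketch of a one-utterance call (the tensor shapes and the sign convention are assumptions stated here, not asserted by this diff):

    # Hypothetical offline decode of one utterance with a U2 model;
    # `model`, `feats` (shape [1, T, D]) and `text_feature` are assumed
    # to already exist.
    import paddle

    feats_lens = paddle.to_tensor([feats.shape[1]])
    hyps = model.decode(
        feats,
        feats_lens,
        text_feature=text_feature,
        decoding_method='attention_rescoring',  # requires batch_size == 1
        beam_size=10,
        ctc_weight=0.5,
        decoding_chunk_size=-1,   # assumed convention: <0 -> full utterance
        num_decoding_left_chunks=-1,
        simulate_streaming=False)
    print(hyps[0][0])  # best transcript of the first (only) utterance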
paddlespeech/s2t/modules/ctc.py
@@ -180,7 +180,8 @@ class CTCDecoder(CTCDecoderBase):
         # init once
         if self._ext_scorer is not None:
             return

+        from paddlespeech.s2t.decoders.ctcdecoder import Scorer  # noqa: F401
         if language_model_path != '':
             logger.info("begin to initialize the external scorer "
                         "for decoding")
paddlespeech/s2t/modules/encoder.py
@@ -317,6 +317,8 @@ class BaseEncoder(nn.Layer):
         outputs = []
         offset = 0
         # Feed forward overlap input step by step
+        print(f"context: {context}")
+        print(f"stride: {stride}")
         for cur in range(0, num_frames - context + 1, stride):
             end = min(cur + decoding_window, num_frames)
             chunk_xs = xs[:, cur:end, :]
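The debug prints above expose the two quantities that drive the streaming loop: each iteration feeds a window of up to decoding_window frames and advances by stride frames, stopping once fewer than context frames remain. A standalone sketch of the window arithmetic with illustrative values (the numbers are made up for the example, not taken from this model):

    # Illustrative chunk windows for the loop above; all values assumed.
    num_frames = 200        # total feature frames in the utterance
    context = 7             # receptive-field frames needed per step
    stride = 64             # frames the window advances each step
    decoding_window = 70    # max frames fed to the encoder per step

    for cur in range(0, num_frames - context + 1, stride):
        end = min(cur + decoding_window, num_frames)
        print(f"chunk frames [{cur}:{end}]")
    # -> [0:70], [64:134], [128:198], [192:200]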
paddlespeech/server/conf/ws_application.yaml
@@ -4,7 +4,7 @@
 # SERVER SETTING                                                                #
 #################################################################################
 host: 0.0.0.0
-port: 8091
+port: 8096

 # The task format in the engin_list is: <speech task>_<engine type>
 # task choices = ['asr_online', 'tts_online']
@@ -18,10 +18,44 @@ engine_list: ['asr_online']
 # ENGINE CONFIG                                                                 #
 #################################################################################

+# ################################### ASR #########################################
+# ################### speech task: asr; engine_type: online #######################
+# asr_online:
+#     model_type: 'deepspeech2online_aishell'
+#     am_model: # the pdmodel file of am static model [optional]
+#     am_params: # the pdiparams file of am static model [optional]
+#     lang: 'zh'
+#     sample_rate: 16000
+#     cfg_path:
+#     decode_method:
+#     force_yes: True
+#
+#     am_predictor_conf:
+#         device: # set 'gpu:id' or 'cpu'
+#         switch_ir_optim: True
+#         glog_info: False  # True -> print glog
+#         summary: True  # False -> do not show predictor config
+#
+#     chunk_buffer_conf:
+#         frame_duration_ms: 80
+#         shift_ms: 40
+#         sample_rate: 16000
+#         sample_width: 2
+#     vad_conf:
+#         aggressiveness: 2
+#         sample_rate: 16000
+#         frame_duration_ms: 20
+#         sample_width: 2
+#         padding_ms: 200
+#         padding_ratio: 0.9

 ################################### ASR #########################################
 ################### speech task: asr; engine_type: online #######################
 asr_online:
-    model_type: 'deepspeech2online_aishell'
+    model_type: 'conformer2online_aishell'
     am_model: # the pdmodel file of am static model [optional]
     am_params: # the pdiparams file of am static model [optional]
     lang: 'zh'
@@ -37,15 +71,15 @@ asr_online:
         summary: True  # False -> do not show predictor config

     chunk_buffer_conf:
-        frame_duration_ms: 80
+        frame_duration_ms: 85
         shift_ms: 40
         sample_rate: 16000
         sample_width: 2
-    vad_conf:
-        aggressiveness: 2
-        sample_rate: 16000
-        frame_duration_ms: 20
-        sample_width: 2
-        padding_ms: 200
-        padding_ratio: 0.9
+    # vad_conf:
+    #     aggressiveness: 2
+    #     sample_rate: 16000
+    #     frame_duration_ms: 20
+    #     sample_width: 2
+    #     padding_ms: 200
+    #     padding_ratio: 0.9
\ No newline at end of file
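The chunk_buffer_conf values determine how many raw bytes the server treats as one frame: at 16 kHz, 16-bit PCM (sample_width 2), an 85 ms frame is 16000 × 0.085 × 2 = 2720 bytes, shifted by 40 ms (1280 bytes) each step. A quick sketch of that arithmetic (a back-of-the-envelope check, not code from this repo):

    # Back-of-the-envelope frame sizing for the chunk_buffer_conf above.
    sample_rate = 16000       # Hz
    sample_width = 2          # bytes per sample (16-bit PCM)
    frame_duration_ms = 85
    shift_ms = 40

    frame_bytes = int(sample_rate * frame_duration_ms / 1000) * sample_width
    shift_bytes = int(sample_rate * shift_ms / 1000) * sample_width
    print(frame_bytes, shift_bytes)  # 2720 1280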
paddlespeech/server/engine/asr/online/asr_engine.py
@@ -20,11 +20,15 @@ from numpy import float32
 from yacs.config import CfgNode

 from paddlespeech.cli.asr.infer import ASRExecutor
+from paddlespeech.cli.asr.infer import model_alias
+from paddlespeech.cli.asr.infer import pretrained_models
 from paddlespeech.cli.log import logger
 from paddlespeech.cli.utils import MODEL_HOME
+from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.frontend.speech import SpeechSegment
 from paddlespeech.s2t.modules.ctc import CTCDecoder
+from paddlespeech.s2t.transform.transformation import Transformation
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.s2t.utils.utility import UpdateConfig
 from paddlespeech.server.engine.base_engine import BaseEngine
 from paddlespeech.server.utils.audio_process import pcm2float
@@ -51,6 +55,24 @@ pretrained_models = {
         'lm_md5':
         '29e02312deb2e59b3c8686c7966d4fe3'
     },
+    "conformer2online_aishell-zh-16k": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz',
+        'md5':
+        '4814e52e0fc2fd48899373f95c84b0c9',
+        'cfg_path':
+        'exp/chunk_conformer//conf/config.yaml',
+        'ckpt_path':
+        'exp/chunk_conformer/checkpoints/avg_30/',
+        'model':
+        'exp/chunk_conformer/checkpoints/avg_30.pdparams',
+        'params':
+        'exp/chunk_conformer/checkpoints/avg_30.pdparams',
+        'lm_url':
+        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+        'lm_md5':
+        '29e02312deb2e59b3c8686c7966d4fe3'
+    },
 }
@@ -71,15 +93,17 @@ class ASRServerExecutor(ASRExecutor):
         """
         Init model and other resources from a specific path.
         """
+        self.model_type = model_type
+        self.sample_rate = sample_rate
         if cfg_path is None or am_model is None or am_params is None:
             sample_rate_str = '16k' if sample_rate == 16000 else '8k'
             tag = model_type + '-' + lang + '-' + sample_rate_str
             logger.info(f"Load the pretrained model, tag = {tag}")
             res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
             self.res_path = res_path
-            self.cfg_path = os.path.join(res_path,
-                                         pretrained_models[tag]['cfg_path'])
+            self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/paddlespeech/server/tests/asr/online/conf/config.yaml"
+            # self.cfg_path = os.path.join(res_path,
+            #                              pretrained_models[tag]['cfg_path'])

             self.am_model = os.path.join(res_path,
                                          pretrained_models[tag]['model'])
@@ -119,49 +143,67 @@ class ASRServerExecutor(ASRExecutor):
                 lm_url,
                 os.path.dirname(self.config.decode.lang_model_path), lm_md5)
+        elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
+            # develop the streaming conformer model
+            logger.info("start to create the stream conformer asr engine")
+            # reuse the code in the cli
+            if self.config.spm_model_prefix:
+                self.config.spm_model_prefix = os.path.join(
+                    self.res_path, self.config.spm_model_prefix)
+            self.config.vocab_filepath = os.path.join(
+                self.res_path, self.config.vocab_filepath)
+            self.text_feature = TextFeaturizer(
+                unit_type=self.config.unit_type,
+                vocab=self.config.vocab_filepath,
+                spm_model_prefix=self.config.spm_model_prefix)
+            # update the decoding method
+            if decode_method:
+                self.config.decode.decoding_method = decode_method
         else:
             raise Exception("wrong type")

-        # AM predictor
-        logger.info("ASR engine start to init the am predictor")
-        self.am_predictor_conf = am_predictor_conf
-        self.am_predictor = init_predictor(
-            model_file=self.am_model,
-            params_file=self.am_params,
-            predictor_conf=self.am_predictor_conf)
-        # decoder
-        logger.info("ASR engine start to create the ctc decoder instance")
-        self.decoder = CTCDecoder(
-            odim=self.config.output_dim,  # <blank> is in vocab
-            enc_n_units=self.config.rnn_layer_size * 2,
-            blank_id=self.config.blank_id,
-            dropout_rate=0.0,
-            reduction=True,  # sum
-            batch_average=True,  # sum / batch_size
-            grad_norm_type=self.config.get('ctc_grad_norm_type', None))
-        # init decoder
-        logger.info("ASR engine start to init the ctc decoder")
-        cfg = self.config.decode
-        decode_batch_size = 1  # for online
-        self.decoder.init_decoder(
-            decode_batch_size, self.text_feature.vocab_list,
-            cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
-            cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
-            cfg.num_proc_bsearch)
-        # init state box
-        self.chunk_state_h_box = np.zeros(
-            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
-            dtype=float32)
-        self.chunk_state_c_box = np.zeros(
-            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
-            dtype=float32)
+        if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
+            # AM predictor
+            logger.info("ASR engine start to init the am predictor")
+            self.am_predictor_conf = am_predictor_conf
+            self.am_predictor = init_predictor(
+                model_file=self.am_model,
+                params_file=self.am_params,
+                predictor_conf=self.am_predictor_conf)
+            # decoder
+            logger.info("ASR engine start to create the ctc decoder instance")
+            self.decoder = CTCDecoder(
+                odim=self.config.output_dim,  # <blank> is in vocab
+                enc_n_units=self.config.rnn_layer_size * 2,
+                blank_id=self.config.blank_id,
+                dropout_rate=0.0,
+                reduction=True,  # sum
+                batch_average=True,  # sum / batch_size
+                grad_norm_type=self.config.get('ctc_grad_norm_type', None))
+            # init decoder
+            logger.info("ASR engine start to init the ctc decoder")
+            cfg = self.config.decode
+            decode_batch_size = 1  # for online
+            self.decoder.init_decoder(
+                decode_batch_size, self.text_feature.vocab_list,
+                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
+                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
+                cfg.num_proc_bsearch)
+            # init state box
+            self.chunk_state_h_box = np.zeros(
+                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
+                dtype=float32)
+            self.chunk_state_c_box = np.zeros(
+                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
+                dtype=float32)
+        elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
+            model_name = model_type[:model_type.rindex(
+                '_')]  # model_type: {model_name}_{dataset}
+            logger.info(f"model name: {model_name}")
+            model_class = dynamic_import(model_name, model_alias)
+            model_conf = self.config
+            model = model_class.from_config(model_conf)
+            self.model = model
+            logger.info("create the transformer like model success")

     def reset_decoder_and_chunk(self):
         """reset decoder and chunk state for an new audio
@@ -186,6 +228,7 @@ class ASRServerExecutor(ASRExecutor):
         Returns:
             [type]: [description]
         """
+        logger.info("start to decoce chunk by chunk")
         if "deepspeech2online" in model_type:
             input_names = self.am_predictor.get_input_names()
             audio_handle = self.am_predictor.get_input_handle(input_names[0])
@@ -224,10 +267,29 @@ class ASRServerExecutor(ASRExecutor):
             self.decoder.next(output_chunk_probs, output_chunk_lens)
             trans_best, trans_beam = self.decoder.decode()
+            logger.info(f"decode one one best result: {trans_best[0]}")
             return trans_best[0]

         elif "conformer" in model_type or "transformer" in model_type:
-            raise Exception("invalid model name")
+            try:
+                logger.info(
+                    f"we will use the transformer like model : {self.model_type}"
+                )
+                cfg = self.config.decode
+                result_transcripts = self.model.decode(
+                    x_chunk,
+                    x_chunk_lens,
+                    text_feature=self.text_feature,
+                    decoding_method=cfg.decoding_method,
+                    beam_size=cfg.beam_size,
+                    ctc_weight=cfg.ctc_weight,
+                    decoding_chunk_size=cfg.decoding_chunk_size,
+                    num_decoding_left_chunks=cfg.num_decoding_left_chunks,
+                    simulate_streaming=cfg.simulate_streaming)
+                return result_transcripts[0][0]
+            except Exception as e:
+                logger.exception(e)
         else:
             raise Exception("invalid model name")
@@ -244,32 +306,55 @@ class ASRServerExecutor(ASRExecutor):
         """
         # pcm16 -> pcm 32
         samples = pcm2float(samples)
-        # read audio
-        speech_segment = SpeechSegment.from_pcm(
-            samples, sample_rate, transcript=" ")
-        # audio augment
-        self.collate_fn_test.augmentation.transform_audio(speech_segment)
-        # extract speech feature
-        spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
-            speech_segment, self.collate_fn_test.keep_transcription_text)
-        # CMVN spectrum
-        if self.collate_fn_test._normalizer:
-            spectrum = self.collate_fn_test._normalizer.apply(spectrum)
-        # spectrum augment
-        audio = self.collate_fn_test.augmentation.transform_feature(spectrum)
-        audio_len = audio.shape[0]
-        audio = paddle.to_tensor(audio, dtype='float32')
-        # audio_len = paddle.to_tensor(audio_len)
-        audio = paddle.unsqueeze(audio, axis=0)
-        x_chunk = audio.numpy()
-        x_chunk_lens = np.array([audio_len])
-        return x_chunk, x_chunk_lens
+        if "deepspeech2online" in self.model_type:
+            # read audio
+            speech_segment = SpeechSegment.from_pcm(
+                samples, sample_rate, transcript=" ")
+            # audio augment
+            self.collate_fn_test.augmentation.transform_audio(speech_segment)
+            # extract speech feature
+            spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
+                speech_segment, self.collate_fn_test.keep_transcription_text)
+            # CMVN spectrum
+            if self.collate_fn_test._normalizer:
+                spectrum = self.collate_fn_test._normalizer.apply(spectrum)
+            # spectrum augment
+            audio = self.collate_fn_test.augmentation.transform_feature(spectrum)
+            audio_len = audio.shape[0]
+            audio = paddle.to_tensor(audio, dtype='float32')
+            # audio_len = paddle.to_tensor(audio_len)
+            audio = paddle.unsqueeze(audio, axis=0)
+            x_chunk = audio.numpy()
+            x_chunk_lens = np.array([audio_len])
+            return x_chunk, x_chunk_lens
+        elif "conformer2online" in self.model_type:
+            if sample_rate != self.sample_rate:
+                logger.info(f"audio sample rate {sample_rate} is not match," \
+                    "the model sample_rate is {self.sample_rate}")
+            logger.info(f"ASR Engine use the {self.model_type} to process")
+            logger.info("Create the preprocess instance")
+            preprocess_conf = self.config.preprocess_config
+            preprocess_args = {"train": False}
+            preprocessing = Transformation(preprocess_conf)
+            logger.info("Read the audio file")
+            logger.info(f"audio shape: {samples.shape}")
+            # fbank
+            x_chunk = preprocessing(samples, **preprocess_args)
+            x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
+            x_chunk = paddle.to_tensor(
+                x_chunk, dtype="float32").unsqueeze(axis=0)
+            logger.info(
+                f"process the audio feature success, feat shape: {x_chunk.shape}")
+            return x_chunk, x_chunk_lens


 class ASREngine(BaseEngine):
@@ -310,7 +395,10 @@ class ASREngine(BaseEngine):
         logger.info("Initialize ASR server engine successfully.")
         return True

-    def preprocess(self, samples, sample_rate):
+    def preprocess(self,
+                   samples,
+                   sample_rate,
+                   model_type="deepspeech2online_aishell-zh-16k"):
         """preprocess
         Args:
@@ -321,6 +409,7 @@ class ASREngine(BaseEngine):
             x_chunk (numpy.array): shape[B, T, D]
             x_chunk_lens (numpy.array): shape[B]
         """
+        # if "deepspeech" in model_type:
         x_chunk, x_chunk_lens = self.executor.extract_feat(samples,
                                                            sample_rate)
         return x_chunk, x_chunk_lens
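extract_feat above first widens the incoming 16-bit PCM to float via pcm2float before branching per model type. A minimal sketch of what that int16-to-float32 conversion typically looks like (the standard normalization, offered as an assumption about pcm2float rather than a copy of its implementation):

    import numpy as np

    def pcm2float_sketch(samples: np.ndarray) -> np.ndarray:
        """Assumed behaviour: map int16 PCM [-32768, 32767] to float32 [-1, 1)."""
        return samples.astype(np.float32) / 32768.0

    frame = (np.sin(np.linspace(0, 1, 160)) * 3000).astype(np.int16)  # fake PCM
    print(pcm2float_sketch(frame).dtype, pcm2float_sketch(frame).max())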
paddlespeech/server/tests/asr/online/websocket_client.py
@@ -103,7 +103,7 @@ class ASRAudioHandler:
 def main(args):
     logging.basicConfig(level=logging.INFO)
     logging.info("asr websocket client start")
-    handler = ASRAudioHandler("127.0.0.1", 8090)
+    handler = ASRAudioHandler("127.0.0.1", 8096)
     loop = asyncio.get_event_loop()

     # support to process single audio file
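The test client now targets the server's new port 8096. For poking the endpoint without the test harness, a minimal sketch of a streaming client follows; it assumes the /ws/asr route from paddlespeech/server/ws/asr_socket.py, a JSON start/end handshake, and raw 16-bit PCM chunks in between. The exact handshake keys are an assumption, so treat ASRAudioHandler in this file as the reference:

    # Hypothetical minimal client; the message schema is assumed, see
    # paddlespeech/server/tests/asr/online/websocket_client.py for the real one.
    import asyncio
    import json
    import wave

    import websockets  # pip install websockets


    async def run(wav_path: str):
        async with websockets.connect("ws://127.0.0.1:8096/ws/asr") as ws:
            await ws.send(json.dumps({"signal": "start"}))  # assumed handshake
            with wave.open(wav_path, "rb") as w:            # 16 kHz, 16-bit mono
                while True:
                    pcm = w.readframes(1360)  # ~85 ms per chunk_buffer_conf
                    if not pcm:
                        break
                    await ws.send(pcm)                       # raw PCM bytes frame
                    print(await ws.recv())                   # partial result
            await ws.send(json.dumps({"signal": "end"}))     # assumed handshake
            print(await ws.recv())                           # final response


    asyncio.get_event_loop().run_until_complete(run("zh.wav"))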
paddlespeech/server/ws/asr_socket.py
@@ -14,6 +14,7 @@
 import json

 import numpy as np
+import json
 from fastapi import APIRouter
 from fastapi import WebSocket
 from fastapi import WebSocketDisconnect
@@ -28,7 +29,7 @@ router = APIRouter()

 @router.websocket('/ws/asr')
 async def websocket_endpoint(websocket: WebSocket):
+    print("websocket protocal receive the dataset")
     await websocket.accept()

     engine_pool = get_engine_pool()
@@ -36,14 +37,18 @@ async def websocket_endpoint(websocket: WebSocket):
     # init buffer
     chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
     chunk_buffer = ChunkBuffer(
         frame_duration_ms=chunk_buffer_conf['frame_duration_ms'],
         sample_rate=chunk_buffer_conf['sample_rate'],
         sample_width=chunk_buffer_conf['sample_width'])
     # init vad
-    vad_conf = asr_engine.config.vad_conf
-    vad = VADAudio(
-        aggressiveness=vad_conf['aggressiveness'],
-        rate=vad_conf['sample_rate'],
-        frame_duration_ms=vad_conf['frame_duration_ms'])
+    # print(asr_engine.config)
+    # print(type(asr_engine.config))
+    vad_conf = asr_engine.config.get('vad_conf', None)
+    if vad_conf:
+        vad = VADAudio(
+            aggressiveness=vad_conf['aggressiveness'],
+            rate=vad_conf['sample_rate'],
+            frame_duration_ms=vad_conf['frame_duration_ms'])

     try:
         while True:
@@ -65,7 +70,7 @@ async def websocket_endpoint(websocket: WebSocket):
                 engine_pool = get_engine_pool()
                 asr_engine = engine_pool['asr']
                 # reset single engine for an new connection
-                asr_engine.reset()
+                # asr_engine.reset()
                 resp = {"status": "ok", "signal": "finished"}
                 await websocket.send_json(resp)
                 break
@@ -75,16 +80,16 @@ async def websocket_endpoint(websocket: WebSocket):
             elif "bytes" in message:
                 message = message["bytes"]

-                # vad for input bytes audio
-                vad.add_audio(message)
-                message = b''.join(f for f in vad.vad_collector()
-                                   if f is not None)
+                # # vad for input bytes audio
+                # vad.add_audio(message)
+                # message = b''.join(f for f in vad.vad_collector()
+                #                    if f is not None)

                 engine_pool = get_engine_pool()
                 asr_engine = engine_pool['asr']
                 asr_results = ""
                 frames = chunk_buffer.frame_generator(message)
                 for frame in frames:
                     # get the pcm data from the bytes
                     samples = np.frombuffer(frame.bytes, dtype=np.int16)
                     sample_rate = asr_engine.config.sample_rate
                     x_chunk, x_chunk_lens = asr_engine.preprocess(samples,