PaddlePaddle / DeepSpeech

Commit d21ccd02
Authored Apr 15, 2022 by xiongxinlei
add conformer online server, test=doc
Parent: af484fc9
Showing 8 changed files with 272 additions and 113 deletions (+272 -113).
paddlespeech/cli/asr/infer.py (+40 -16)
paddlespeech/s2t/models/u2/u2.py (+6 -2)
paddlespeech/s2t/modules/ctc.py (+2 -1)
paddlespeech/s2t/modules/encoder.py (+2 -0)
paddlespeech/server/conf/ws_application.yaml (+44 -10)
paddlespeech/server/engine/asr/online/asr_engine.py (+160 -71)
paddlespeech/server/tests/asr/online/websocket_client.py (+1 -1)
paddlespeech/server/ws/asr_socket.py (+17 -12)
paddlespeech/cli/asr/infer.py
@@ -91,6 +91,20 @@ pretrained_models = {
         'lm_url':
         'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
         'lm_md5':
         '29e02312deb2e59b3c8686c7966d4fe3'
     },
+    "conformer2online_aishell-zh-16k": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz',
+        'md5':
+        '4814e52e0fc2fd48899373f95c84b0c9',
+        'cfg_path':
+        'config.yaml',
+        'ckpt_path':
+        'exp/deepspeech2_online/checkpoints/avg_30',
+        'lm_url':
+        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+        'lm_md5':
+        '29e02312deb2e59b3c8686c7966d4fe3'
+    },
     "deepspeech2offline_librispeech-en-16k": {
         'url':

@@ -115,6 +129,8 @@ model_alias = {
     "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
     "conformer":
     "paddlespeech.s2t.models.u2:U2Model",
+    "conformer2online":
+    "paddlespeech.s2t.models.u2:U2Model",
     "transformer":
     "paddlespeech.s2t.models.u2:U2Model",
     "wenetspeech":

@@ -219,6 +235,7 @@ class ASRExecutor(BaseExecutor):
         """
         Init model and other resources from a specific path.
         """
+        logger.info("start to init the model")
         if hasattr(self, 'model'):
             logger.info('Model had been initialized.')
             return

@@ -233,14 +250,15 @@ class ASRExecutor(BaseExecutor):
             self.ckpt_path = os.path.join(
                 res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams")
             logger.info(res_path)
-            logger.info(self.cfg_path)
-            logger.info(self.ckpt_path)
         else:
             self.cfg_path = os.path.abspath(cfg_path)
             self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
             self.res_path = os.path.dirname(
                 os.path.dirname(os.path.abspath(self.cfg_path)))
+        logger.info(self.cfg_path)
+        logger.info(self.ckpt_path)

         #Init body.
         self.config = CfgNode(new_allowed=True)
         self.config.merge_from_file(self.cfg_path)

@@ -269,7 +287,6 @@ class ASRExecutor(BaseExecutor):
                 vocab=self.config.vocab_filepath,
                 spm_model_prefix=self.config.spm_model_prefix)
             self.config.decode.decoding_method = decode_method
-
         else:
             raise Exception("wrong type")
         model_name = model_type[:model_type.rindex(

@@ -347,12 +364,14 @@ class ASRExecutor(BaseExecutor):
         else:
             raise Exception("wrong type")
+        logger.info("audio feat process success")

     @paddle.no_grad()
     def infer(self, model_type: str):
         """
         Model inference and result stored in self.output.
         """
+        logger.info("start to infer the model to get the output")
         cfg = self.config.decode
         audio = self._inputs["audio"]
         audio_len = self._inputs["audio_len"]

@@ -369,17 +388,22 @@ class ASRExecutor(BaseExecutor):
             self._outputs["result"] = result_transcripts[0]
         elif "conformer" in model_type or "transformer" in model_type:
-            result_transcripts = self.model.decode(
-                audio,
-                audio_len,
-                text_feature=self.text_feature,
-                decoding_method=cfg.decoding_method,
-                beam_size=cfg.beam_size,
-                ctc_weight=cfg.ctc_weight,
-                decoding_chunk_size=cfg.decoding_chunk_size,
-                num_decoding_left_chunks=cfg.num_decoding_left_chunks,
-                simulate_streaming=cfg.simulate_streaming)
-            self._outputs["result"] = result_transcripts[0][0]
+            logger.info(
+                f"we will use the transformer like model : {model_type}")
+            try:
+                result_transcripts = self.model.decode(
+                    audio,
+                    audio_len,
+                    text_feature=self.text_feature,
+                    decoding_method=cfg.decoding_method,
+                    beam_size=cfg.beam_size,
+                    ctc_weight=cfg.ctc_weight,
+                    decoding_chunk_size=cfg.decoding_chunk_size,
+                    num_decoding_left_chunks=cfg.num_decoding_left_chunks,
+                    simulate_streaming=cfg.simulate_streaming)
+                self._outputs["result"] = result_transcripts[0][0]
+            except Exception as e:
+                logger.exception(e)
         else:
             raise Exception("invalid model name")
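The new pretrained-model entry is looked up by a tag built from model type, language, and sample rate, the same construction the server executor below uses. A minimal self-contained sketch of that resolution (not PaddleSpeech API; the dictionary is an abbreviated copy of the entry added above):

```python
# Abbreviated copy of the entry added in this commit (see diff above).
pretrained_models = {
    "conformer2online_aishell-zh-16k": {
        "url": "https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/"
               "asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz",
        "md5": "4814e52e0fc2fd48899373f95c84b0c9",
        "cfg_path": "config.yaml",
    },
}

def resolve_tag(model_type: str, lang: str, sample_rate: int) -> str:
    # Same construction as the executor: <model_type>-<lang>-<16k|8k>
    sample_rate_str = "16k" if sample_rate == 16000 else "8k"
    return model_type + "-" + lang + "-" + sample_rate_str

tag = resolve_tag("conformer2online_aishell", "zh", 16000)
assert tag == "conformer2online_aishell-zh-16k"
print(pretrained_models[tag]["md5"])
```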
paddlespeech/s2t/models/u2/u2.py
@@ -213,12 +213,14 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 num_decoding_left_chunks=num_decoding_left_chunks
             )  # (B, maxlen, encoder_dim)
         else:
+            print("offline decode from the asr")
             encoder_out, encoder_mask = self.encoder(
                 speech,
                 speech_lengths,
                 decoding_chunk_size=decoding_chunk_size,
                 num_decoding_left_chunks=num_decoding_left_chunks
             )  # (B, maxlen, encoder_dim)
+            print("offline decode success")
         return encoder_out, encoder_mask

     def recognize(

@@ -706,13 +708,15 @@ class U2BaseModel(ASRInterface, nn.Layer):
             List[List[int]]: transcripts.
         """
         batch_size = feats.shape[0]
+        print("start to decode the audio feat")
         if decoding_method in ['ctc_prefix_beam_search',
                                'attention_rescoring'] and batch_size > 1:
-            logger.fatal(
+            logger.error(
                 f'decoding mode {decoding_method} must be running with batch_size == 1')
+            logger.error(f"current batch_size is {batch_size}")
             sys.exit(1)
+        print(f"use the {decoding_method} to decode the audio feat")
         if decoding_method == 'attention':
             hyps = self.recognize(
                 feats,
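The guard above now reports the offending batch size before exiting, since ctc_prefix_beam_search and attention_rescoring only accept batch_size == 1. A hedged sketch of how a caller could respect that constraint by splitting a batch (`model`, `feats`, and `feats_lens` are stand-ins, not the real objects):

```python
# Decode utterance by utterance; slicing [i:i + 1] keeps a batch dimension
# of size 1, which is what the decode entry point requires for these modes.
def decode_one_by_one(model, feats, feats_lens, **decode_kwargs):
    results = []
    for i in range(feats.shape[0]):
        hyp = model.decode(feats[i:i + 1], feats_lens[i:i + 1], **decode_kwargs)
        results.append(hyp)
    return results
```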
paddlespeech/s2t/modules/ctc.py
@@ -180,7 +180,8 @@ class CTCDecoder(CTCDecoderBase):
         # init once
         if self._ext_scorer is not None:
             return
+        from paddlespeech.s2t.decoders.ctcdecoder import Scorer  # noqa: F401
         if language_model_path != '':
             logger.info("begin to initialize the external scorer "
                         "for decoding")
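Moving the Scorer import into the method body defers loading the native ctcdecoder extension until an external scorer is actually initialized. A minimal sketch of the same deferred-import pattern (the cached resolver is an illustration, not PaddleSpeech API):

```python
import importlib

# Resolve the scorer class only on first use, so decoding without an
# external language model never pays the extension's import cost.
_scorer_cls = None

def get_scorer_cls():
    global _scorer_cls
    if _scorer_cls is None:
        # the real code does: from paddlespeech.s2t.decoders.ctcdecoder import Scorer
        mod = importlib.import_module("paddlespeech.s2t.decoders.ctcdecoder")
        _scorer_cls = getattr(mod, "Scorer")
    return _scorer_cls
```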
paddlespeech/s2t/modules/encoder.py
@@ -317,6 +317,8 @@ class BaseEncoder(nn.Layer):
         outputs = []
         offset = 0
         # Feed forward overlap input step by step
+        print(f"context: {context}")
+        print(f"stride: {stride}")
         for cur in range(0, num_frames - context + 1, stride):
             end = min(cur + decoding_window, num_frames)
             chunk_xs = xs[:, cur:end, :]
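The two added prints expose the windowing parameters of this chunk loop: each forward pass consumes up to decoding_window frames, the start index advances by stride, and the loop stops once fewer than context frames remain. A small sketch with made-up values shows the boundaries it produces:

```python
# Hypothetical values: context = receptive field of one step, stride = hop
# between chunk starts, decoding_window = frames consumed per forward pass.
num_frames, context, stride, decoding_window = 200, 7, 64, 70

chunks = []
for cur in range(0, num_frames - context + 1, stride):
    end = min(cur + decoding_window, num_frames)  # clip the last window
    chunks.append((cur, end))

print(chunks)  # [(0, 70), (64, 134), (128, 198), (192, 200)]
```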
paddlespeech/server/conf/ws_application.yaml
@@ -4,7 +4,7 @@
 # SERVER SETTING                                                                #
 #################################################################################
 host: 0.0.0.0
-port: 8091
+port: 8096

 # The task format in the engin_list is: <speech task>_<engine type>
 # task choices = ['asr_online', 'tts_online']

@@ -18,10 +18,44 @@ engine_list: ['asr_online']
 # ENGINE CONFIG                                                                 #
 #################################################################################
+# ################################### ASR #########################################
+# ################### speech task: asr; engine_type: online #######################
+# asr_online:
+#     model_type: 'deepspeech2online_aishell'
+#     am_model: # the pdmodel file of am static model [optional]
+#     am_params: # the pdiparams file of am static model [optional]
+#     lang: 'zh'
+#     sample_rate: 16000
+#     cfg_path:
+#     decode_method:
+#     force_yes: True
+#     am_predictor_conf:
+#         device: # set 'gpu:id' or 'cpu'
+#         switch_ir_optim: True
+#         glog_info: False # True -> print glog
+#         summary: True # False -> do not show predictor config
+#     chunk_buffer_conf:
+#         frame_duration_ms: 80
+#         shift_ms: 40
+#         sample_rate: 16000
+#         sample_width: 2
+#     vad_conf:
+#         aggressiveness: 2
+#         sample_rate: 16000
+#         frame_duration_ms: 20
+#         sample_width: 2
+#         padding_ms: 200
+#         padding_ratio: 0.9
+
 ################################### ASR #########################################
 ################### speech task: asr; engine_type: online #######################
 asr_online:
-    model_type: 'deepspeech2online_aishell'
+    model_type: 'conformer2online_aishell'
     am_model: # the pdmodel file of am static model [optional]
     am_params: # the pdiparams file of am static model [optional]
     lang: 'zh'

@@ -37,15 +71,15 @@ asr_online:
         summary: True # False -> do not show predictor config

     chunk_buffer_conf:
-        frame_duration_ms: 80
+        frame_duration_ms: 85
         shift_ms: 40
         sample_rate: 16000
         sample_width: 2

-    vad_conf:
-        aggressiveness: 2
-        sample_rate: 16000
-        frame_duration_ms: 20
-        sample_width: 2
-        padding_ms: 200
-        padding_ratio: 0.9
+    # vad_conf:
+    #     aggressiveness: 2
+    #     sample_rate: 16000
+    #     frame_duration_ms: 20
+    #     sample_width: 2
+    #     padding_ms: 200
+    #     padding_ratio: 0.9
\ No newline at end of file
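The engine reads this file through yacs, the same reader the executor uses above (CfgNode(new_allowed=True).merge_from_file(...)), so the values this commit changes can be sanity-checked directly. A minimal sketch, assuming it is run from the repository root:

```python
from yacs.config import CfgNode

# Load the edited config and confirm the port, model type, and frame
# duration introduced by this commit.
config = CfgNode(new_allowed=True)
config.merge_from_file("paddlespeech/server/conf/ws_application.yaml")

assert config.port == 8096
assert config.asr_online.model_type == "conformer2online_aishell"
assert config.asr_online.chunk_buffer_conf.frame_duration_ms == 85
```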
paddlespeech/server/engine/asr/online/asr_engine.py
@@ -20,11 +20,15 @@ from numpy import float32
 from yacs.config import CfgNode

 from paddlespeech.cli.asr.infer import ASRExecutor
+from paddlespeech.cli.asr.infer import model_alias
+from paddlespeech.cli.asr.infer import pretrained_models
 from paddlespeech.cli.log import logger
 from paddlespeech.cli.utils import MODEL_HOME
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.frontend.speech import SpeechSegment
 from paddlespeech.s2t.modules.ctc import CTCDecoder
+from paddlespeech.s2t.transform.transformation import Transformation
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.s2t.utils.utility import UpdateConfig
 from paddlespeech.server.engine.base_engine import BaseEngine
 from paddlespeech.server.utils.audio_process import pcm2float

@@ -51,6 +55,24 @@ pretrained_models = {
         'lm_md5':
         '29e02312deb2e59b3c8686c7966d4fe3'
     },
+    "conformer2online_aishell-zh-16k": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz',
+        'md5':
+        '4814e52e0fc2fd48899373f95c84b0c9',
+        'cfg_path':
+        'exp/chunk_conformer//conf/config.yaml',
+        'ckpt_path':
+        'exp/chunk_conformer/checkpoints/avg_30/',
+        'model':
+        'exp/chunk_conformer/checkpoints/avg_30.pdparams',
+        'params':
+        'exp/chunk_conformer/checkpoints/avg_30.pdparams',
+        'lm_url':
+        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+        'lm_md5':
+        '29e02312deb2e59b3c8686c7966d4fe3'
+    },
 }

@@ -71,15 +93,17 @@ class ASRServerExecutor(ASRExecutor):
         """
         Init model and other resources from a specific path.
         """
+        self.model_type = model_type
+        self.sample_rate = sample_rate
         if cfg_path is None or am_model is None or am_params is None:
             sample_rate_str = '16k' if sample_rate == 16000 else '8k'
             tag = model_type + '-' + lang + '-' + sample_rate_str
             logger.info(f"Load the pretrained model, tag = {tag}")
             res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
             self.res_path = res_path
-            self.cfg_path = os.path.join(res_path,
-                                         pretrained_models[tag]['cfg_path'])
+            self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/paddlespeech/server/tests/asr/online/conf/config.yaml"
+            # self.cfg_path = os.path.join(res_path,
+            #                              pretrained_models[tag]['cfg_path'])
             self.am_model = os.path.join(res_path,
                                          pretrained_models[tag]['model'])

@@ -119,49 +143,67 @@ class ASRServerExecutor(ASRExecutor):
                 lm_url,
                 os.path.dirname(self.config.decode.lang_model_path), lm_md5)
         elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
+            # develop the streaming conformer model
             logger.info("start to create the stream conformer asr engine")
+            # reuse the code from the cli
+            if self.config.spm_model_prefix:
+                self.config.spm_model_prefix = os.path.join(
+                    self.res_path, self.config.spm_model_prefix)
+            self.config.vocab_filepath = os.path.join(
+                self.res_path, self.config.vocab_filepath)
+            self.text_feature = TextFeaturizer(
+                unit_type=self.config.unit_type,
+                vocab=self.config.vocab_filepath,
+                spm_model_prefix=self.config.spm_model_prefix)
+            # update the decoding method
+            if decode_method:
+                self.config.decode.decoding_method = decode_method
         else:
             raise Exception("wrong type")

+        if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
             # AM predictor
             logger.info("ASR engine start to init the am predictor")
             self.am_predictor_conf = am_predictor_conf
             self.am_predictor = init_predictor(
                 model_file=self.am_model,
                 params_file=self.am_params,
                 predictor_conf=self.am_predictor_conf)

             # decoder
             logger.info("ASR engine start to create the ctc decoder instance")
             self.decoder = CTCDecoder(
                 odim=self.config.output_dim,  # <blank> is in vocab
                 enc_n_units=self.config.rnn_layer_size * 2,
                 blank_id=self.config.blank_id,
                 dropout_rate=0.0,
                 reduction=True,  # sum
                 batch_average=True,  # sum / batch_size
                 grad_norm_type=self.config.get('ctc_grad_norm_type', None))

             # init decoder
             logger.info("ASR engine start to init the ctc decoder")
             cfg = self.config.decode
             decode_batch_size = 1  # for online
             self.decoder.init_decoder(
                 decode_batch_size, self.text_feature.vocab_list,
                 cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                 cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                 cfg.num_proc_bsearch)

             # init state box
             self.chunk_state_h_box = np.zeros(
                 (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                 dtype=float32)
             self.chunk_state_c_box = np.zeros(
                 (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                 dtype=float32)
+        elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
+            model_name = model_type[:model_type.rindex(
+                '_')]  # model_type: {model_name}_{dataset}
+            logger.info(f"model name: {model_name}")
+            model_class = dynamic_import(model_name, model_alias)
+            model_conf = self.config
+            model = model_class.from_config(model_conf)
+            self.model = model
+            logger.info("create the transformer like model success")

     def reset_decoder_and_chunk(self):
         """reset decoder and chunk state for an new audio

@@ -186,6 +228,7 @@ class ASRServerExecutor(ASRExecutor):
         Returns:
             [type]: [description]
         """
+        logger.info("start to decoce chunk by chunk")
         if "deepspeech2online" in model_type:
             input_names = self.am_predictor.get_input_names()
             audio_handle = self.am_predictor.get_input_handle(input_names[0])

@@ -224,10 +267,29 @@ class ASRServerExecutor(ASRExecutor):
             self.decoder.next(output_chunk_probs, output_chunk_lens)
             trans_best, trans_beam = self.decoder.decode()
+            logger.info(f"decode one one best result: {trans_best[0]}")
             return trans_best[0]

         elif "conformer" in model_type or "transformer" in model_type:
-            raise Exception("invalid model name")
+            try:
+                logger.info(
+                    f"we will use the transformer like model : {self.model_type}")
+                cfg = self.config.decode
+                result_transcripts = self.model.decode(
+                    x_chunk,
+                    x_chunk_lens,
+                    text_feature=self.text_feature,
+                    decoding_method=cfg.decoding_method,
+                    beam_size=cfg.beam_size,
+                    ctc_weight=cfg.ctc_weight,
+                    decoding_chunk_size=cfg.decoding_chunk_size,
+                    num_decoding_left_chunks=cfg.num_decoding_left_chunks,
+                    simulate_streaming=cfg.simulate_streaming)
+                return result_transcripts[0][0]
+            except Exception as e:
+                logger.exception(e)
         else:
             raise Exception("invalid model name")

@@ -244,32 +306,55 @@ class ASRServerExecutor(ASRExecutor):
         """
         # pcm16 -> pcm 32
         samples = pcm2float(samples)
+        if "deepspeech2online" in self.model_type:
             # read audio
             speech_segment = SpeechSegment.from_pcm(
                 samples, sample_rate, transcript=" ")
             # audio augment
             self.collate_fn_test.augmentation.transform_audio(speech_segment)

             # extract speech feature
             spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
                 speech_segment, self.collate_fn_test.keep_transcription_text)
             # CMVN spectrum
             if self.collate_fn_test._normalizer:
                 spectrum = self.collate_fn_test._normalizer.apply(spectrum)

             # spectrum augment
             audio = self.collate_fn_test.augmentation.transform_feature(spectrum)

             audio_len = audio.shape[0]
             # audio_len = paddle.to_tensor(audio_len)
             audio = paddle.to_tensor(audio, dtype='float32')
             audio = paddle.unsqueeze(audio, axis=0)

             x_chunk = audio.numpy()
             x_chunk_lens = np.array([audio_len])

             return x_chunk, x_chunk_lens
+        elif "conformer2online" in self.model_type:
+            if sample_rate != self.sample_rate:
+                logger.info(f"audio sample rate {sample_rate} is not match," \
+                    "the model sample_rate is {self.sample_rate}")
+            logger.info(f"ASR Engine use the {self.model_type} to process")
+            logger.info("Create the preprocess instance")
+            preprocess_conf = self.config.preprocess_config
+            preprocess_args = {"train": False}
+            preprocessing = Transformation(preprocess_conf)
+
+            logger.info("Read the audio file")
+            logger.info(f"audio shape: {samples.shape}")
+            # fbank
+            x_chunk = preprocessing(samples, **preprocess_args)
+            x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
+            x_chunk = paddle.to_tensor(
+                x_chunk, dtype="float32").unsqueeze(axis=0)
+            logger.info(
+                f"process the audio feature success, feat shape: {x_chunk.shape}")
+            return x_chunk, x_chunk_lens


 class ASREngine(BaseEngine):

@@ -310,7 +395,10 @@ class ASREngine(BaseEngine):
         logger.info("Initialize ASR server engine successfully.")
         return True

-    def preprocess(self, samples, sample_rate):
+    def preprocess(self,
+                   samples,
+                   sample_rate,
+                   model_type="deepspeech2online_aishell-zh-16k"):
         """preprocess
         Args:

@@ -321,6 +409,7 @@ class ASREngine(BaseEngine):
             x_chunk (numpy.array): shape[B, T, D]
             x_chunk_lens (numpy.array): shape[B]
         """
+        # if "deepspeech" in model_type:
         x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
         return x_chunk, x_chunk_lens
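The conformer branch instantiates the model through the alias table extended in infer.py: dynamic_import maps the model-name prefix to a "module:Class" path and from_config builds the model from the YAML config. A minimal sketch of that lookup (the simplified resolver is an assumption about what dynamic_import does; names follow the diff):

```python
import importlib

# Stand-in for the alias table extended in paddlespeech/cli/asr/infer.py.
model_alias = {
    "conformer2online": "paddlespeech.s2t.models.u2:U2Model",
}

def dynamic_import_sketch(name: str, alias: dict):
    # Split the "module:Class" spec and resolve it with importlib.
    module_path, cls_name = alias[name].split(":")
    return getattr(importlib.import_module(module_path), cls_name)

# model_type "conformer2online_aishell" -> model name "conformer2online"
model_type = "conformer2online_aishell"
model_name = model_type[:model_type.rindex("_")]
ModelClass = dynamic_import_sketch(model_name, model_alias)  # needs paddlespeech installed
# the engine then builds it from the YAML config: ModelClass.from_config(config)
```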
paddlespeech/server/tests/asr/online/websocket_client.py
@@ -103,7 +103,7 @@ class ASRAudioHandler:
 def main(args):
     logging.basicConfig(level=logging.INFO)
     logging.info("asr websocket client start")
-    handler = ASRAudioHandler("127.0.0.1", 8090)
+    handler = ASRAudioHandler("127.0.0.1", 8096)
     loop = asyncio.get_event_loop()

     # support to process single audio file
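The test client now targets 8096, matching the port change in ws_application.yaml. A hedged sketch of reaching the endpoint registered in asr_socket.py ('/ws/asr') with a generic third-party websocket client; the start-signal payload is an assumption, and the full chunked-PCM protocol follows ASRAudioHandler in websocket_client.py:

```python
import asyncio
import json

import websockets  # generic third-party websocket client library

async def ping_server():
    # port 8096 and the '/ws/asr' route come from this commit's config and router
    async with websockets.connect("ws://127.0.0.1:8096/ws/asr") as ws:
        # hypothetical start message; see ASRAudioHandler for the real protocol
        await ws.send(json.dumps({"name": "test.wav", "signal": "start"}))
        print(await ws.recv())

asyncio.get_event_loop().run_until_complete(ping_server())
```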
paddlespeech/server/ws/asr_socket.py
@@ -14,6 +14,7 @@
 import json

 import numpy as np
+import json
 from fastapi import APIRouter
 from fastapi import WebSocket
 from fastapi import WebSocketDisconnect

@@ -28,7 +29,7 @@ router = APIRouter()
 @router.websocket('/ws/asr')
 async def websocket_endpoint(websocket: WebSocket):
+    print("websocket protocal receive the dataset")
     await websocket.accept()

     engine_pool = get_engine_pool()

@@ -36,14 +37,18 @@ async def websocket_endpoint(websocket: WebSocket):
     # init buffer
     chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
     chunk_buffer = ChunkBuffer(
+        frame_duration_ms=chunk_buffer_conf['frame_duration_ms'],
         sample_rate=chunk_buffer_conf['sample_rate'],
         sample_width=chunk_buffer_conf['sample_width'])
     # init vad
-    vad_conf = asr_engine.config.vad_conf
-    vad = VADAudio(
-        aggressiveness=vad_conf['aggressiveness'],
-        rate=vad_conf['sample_rate'],
-        frame_duration_ms=vad_conf['frame_duration_ms'])
+    # print(asr_engine.config)
+    # print(type(asr_engine.config))
+    vad_conf = asr_engine.config.get('vad_conf', None)
+    if vad_conf:
+        vad = VADAudio(
+            aggressiveness=vad_conf['aggressiveness'],
+            rate=vad_conf['sample_rate'],
+            frame_duration_ms=vad_conf['frame_duration_ms'])

     try:
         while True:

@@ -65,7 +70,7 @@ async def websocket_endpoint(websocket: WebSocket):
                 engine_pool = get_engine_pool()
                 asr_engine = engine_pool['asr']
                 # reset single engine for an new connection
-                asr_engine.reset()
+                # asr_engine.reset()
                 resp = {"status": "ok", "signal": "finished"}
                 await websocket.send_json(resp)
                 break

@@ -75,16 +80,16 @@ async def websocket_endpoint(websocket: WebSocket):
             elif "bytes" in message:
                 message = message["bytes"]

-                # vad for input bytes audio
-                vad.add_audio(message)
-                message = b''.join(f for f in vad.vad_collector()
-                                   if f is not None)
+                # # vad for input bytes audio
+                # vad.add_audio(message)
+                # message = b''.join(f for f in vad.vad_collector()
+                #                    if f is not None)

                 engine_pool = get_engine_pool()
                 asr_engine = engine_pool['asr']
                 asr_results = ""
                 frames = chunk_buffer.frame_generator(message)
                 for frame in frames:
+                    # get the pcm data from the bytes
                     samples = np.frombuffer(frame.bytes, dtype=np.int16)
                     sample_rate = asr_engine.config.sample_rate
                     x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
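Each frame's raw websocket bytes are reinterpreted as int16 PCM before the engine's pcm2float turns them into float32. A minimal self-contained sketch of that conversion, assuming pcm2float normalizes by the int16 range (the usual convention; the real helper lives in paddlespeech.server.utils.audio_process):

```python
import numpy as np

# Build a fake 10 ms frame of 16 kHz int16 audio to stand in for frame.bytes.
frame_bytes = (np.sin(np.linspace(0, 3.14, 160)) * 32767).astype(np.int16).tobytes()

samples = np.frombuffer(frame_bytes, dtype=np.int16)  # as in asr_socket.py
floats = samples.astype(np.float32) / 2**15           # assumed pcm2float behavior

print(samples.dtype, floats.dtype, floats.min(), floats.max())
```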