PaddlePaddle / DeepSpeech
Commit 380afbbc
Authored Apr 19, 2022 by xiongxinlei

add ds2 model multi session, test=doc

Parent: 5acb0b52

This commit switches the default websocket serving config to the DeepSpeech2 online model, moves the conformer settings into a new ws_conformer_application.yaml, and fills in the DeepSpeech2 streaming path in the per-connection handler: feature caching, chunked decoding, and per-session RNN state.

Showing 3 changed files with 263 additions and 56 deletions (+263 −56):
- paddlespeech/server/conf/ws_application.yaml (+13 −37)
- paddlespeech/server/conf/ws_conformer_application.yaml (+45 −0)
- paddlespeech/server/engine/asr/online/asr_engine.py (+205 −19)
paddlespeech/server/conf/ws_application.yaml

@@ -18,44 +18,10 @@ engine_list: ['asr_online']
 #                                ENGINE CONFIG                                 #
 #################################################################################
-# ################################### ASR #########################################
-# ################### speech task: asr; engine_type: online #######################
-# asr_online:
-#     model_type: 'deepspeech2online_aishell'
-#     am_model: # the pdmodel file of am static model [optional]
-#     am_params: # the pdiparams file of am static model [optional]
-#     lang: 'zh'
-#     sample_rate: 16000
-#     cfg_path:
-#     decode_method:
-#     force_yes: True
-#     am_predictor_conf:
-#         device: # set 'gpu:id' or 'cpu'
-#         switch_ir_optim: True
-#         glog_info: False # True -> print glog
-#         summary: True # False -> do not show predictor config
-#     chunk_buffer_conf:
-#         frame_duration_ms: 80
-#         shift_ms: 40
-#         sample_rate: 16000
-#         sample_width: 2
-#     vad_conf:
-#         aggressiveness: 2
-#         sample_rate: 16000
-#         frame_duration_ms: 20
-#         sample_width: 2
-#         padding_ms: 200
-#         padding_ratio: 0.9
 ################################### ASR #########################################
 ################### speech task: asr; engine_type: online #######################
 asr_online:
-    model_type: 'conformer2online_aishell'
+    model_type: 'deepspeech2online_aishell'
     am_model: # the pdmodel file of am static model [optional]
     am_params: # the pdiparams file of am static model [optional]
     lang: 'zh'
@@ -71,9 +37,19 @@ asr_online:
         summary: True # False -> do not show predictor config

     chunk_buffer_conf:
+        frame_duration_ms: 80
+        shift_ms: 40
+        sample_rate: 16000
+        sample_width: 2
         window_n: 7 # frame
         shift_n: 4 # frame
-        window_ms: 25 # ms
+        window_ms: 20 # ms
         shift_ms: 10 # ms
-        sample_rate: 16000
-        sample_width: 2
\ No newline at end of file
+
+    vad_conf:
+        aggressiveness: 2
+        sample_rate: 16000
+        sample_width: 2
+        frame_duration_ms: 20
+        padding_ms: 200
+        padding_ratio: 0.9
\ No newline at end of file
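The chunk_buffer_conf windowing values are in milliseconds; at the configured 16 kHz sample rate they correspond to the following sample counts (a standalone arithmetic sketch, not part of the commit):

sample_rate = 16000           # asr_online.sample_rate
window_ms, shift_ms = 20, 10  # chunk_buffer_conf values after this change

win_length = int(window_ms / 1000 * sample_rate)  # 320 samples per analysis window
n_shift = int(shift_ms / 1000 * sample_rate)      # 160 samples per frame shift
print(win_length, n_shift)  # -> 320 160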
paddlespeech/server/conf/ws_conformer_application.yaml (new file, mode 100644)

# This is the parameter configuration file for PaddleSpeech Serving.

#################################################################################
#                             SERVER SETTING                                    #
#################################################################################
host: 0.0.0.0
port: 8090

# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only support online engine type.
protocol: 'websocket'
engine_list: ['asr_online']

#################################################################################
#                                ENGINE CONFIG                                  #
#################################################################################

################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
    model_type: 'conformer2online_aishell'
    am_model: # the pdmodel file of am static model [optional]
    am_params: # the pdiparams file of am static model [optional]
    lang: 'zh'
    sample_rate: 16000
    cfg_path:
    decode_method:
    force_yes: True

    am_predictor_conf:
        device: # set 'gpu:id' or 'cpu'
        switch_ir_optim: True
        glog_info: False # True -> print glog
        summary: True # False -> do not show predictor config

    chunk_buffer_conf:
        window_n: 7 # frame
        shift_n: 4 # frame
        window_ms: 25 # ms
        shift_ms: 10 # ms
        sample_rate: 16000
        sample_width: 2
\ No newline at end of file
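A quick way to sanity-check the new file is to load it with PyYAML (a sketch assuming it runs from the repository root; PyYAML itself is not introduced by this commit):

import yaml

# Load the new conformer serving config and inspect the engine selection.
with open("paddlespeech/server/conf/ws_conformer_application.yaml") as f:
    conf = yaml.safe_load(f)

print(conf["protocol"])                  # websocket
print(conf["engine_list"])               # ['asr_online']
print(conf["asr_online"]["model_type"])  # conformer2online_aishell

The file is then typically passed to the serving CLI, e.g. paddlespeech_server start --config_file <path>, assuming the standard PaddleSpeech serving entry point.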
paddlespeech/server/engine/asr/online/asr_engine.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from typing import Optional
+import copy
 import numpy as np
 import paddle
 from numpy import float32
@@ -93,7 +93,7 @@ class PaddleASRConnectionHanddler:
         )
         self.config = asr_engine.config
         self.model_config = asr_engine.executor.config
-        self.model = asr_engine.executor.model
+        # self.model = asr_engine.executor.model
         self.asr_engine = asr_engine
         self.init()
@@ -102,7 +102,32 @@ class PaddleASRConnectionHanddler:
     def init(self):
         self.model_type = self.asr_engine.executor.model_type
         if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
-            pass
+            from paddlespeech.s2t.io.collator import SpeechCollator
+            self.sample_rate = self.asr_engine.executor.sample_rate
+            self.am_predictor = self.asr_engine.executor.am_predictor
+            self.text_feature = self.asr_engine.executor.text_feature
+            self.collate_fn_test = SpeechCollator.from_config(self.model_config)
+            self.decoder = CTCDecoder(
+                odim=self.model_config.output_dim,  # <blank> is in vocab
+                enc_n_units=self.model_config.rnn_layer_size * 2,
+                blank_id=self.model_config.blank_id,
+                dropout_rate=0.0,
+                reduction=True,  # sum
+                batch_average=True,  # sum / batch_size
+                grad_norm_type=self.model_config.get('ctc_grad_norm_type', None))
+
+            cfg = self.model_config.decode
+            decode_batch_size = 1  # for online
+            self.decoder.init_decoder(
+                decode_batch_size, self.text_feature.vocab_list,
+                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
+                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
+                cfg.num_proc_bsearch)
+
+            # frame window samples length and frame shift samples length
+            self.win_length = int(self.model_config.window_ms * self.sample_rate)
+            self.n_shift = int(self.model_config.stride_ms * self.sample_rate)
         elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
             self.sample_rate = self.asr_engine.executor.sample_rate
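Together with extract_feat() and decode() below, the per-connection flow for the DeepSpeech2 path looks roughly as follows. This is a hedged sketch: the websocket glue that drives these calls is not part of this diff, and asr_engine / audio_packages stand in for an already-initialized engine and the client's stream of int16 PCM chunks.

from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler

def run_session(asr_engine, audio_packages):
    # One handler per websocket connection keeps all streaming state session-local.
    handler = PaddleASRConnectionHanddler(asr_engine)
    handler.reset()                        # fresh caches + per-session RNN state
    for pcm16_bytes in audio_packages:
        handler.extract_feat(pcm16_bytes)  # accumulate features into cached_feat
        handler.decode(is_finished=False)  # chunked DS2 forward + CTC decoding
    handler.decode(is_finished=True)       # flush the remaining context frames
    return handler.result_transcripts[0]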
@@ -127,7 +152,65 @@ class PaddleASRConnectionHanddler:
     def extract_feat(self, samples):
         if "deepspeech2online" in self.model_type:
-            pass
+            # self.reamined_wav stores all the samples,
+            # include the original remained_wav and this package samples
+            samples = np.frombuffer(samples, dtype=np.int16)
+            assert samples.ndim == 1
+
+            if self.remained_wav is None:
+                self.remained_wav = samples
+            else:
+                assert self.remained_wav.ndim == 1
+                self.remained_wav = np.concatenate([self.remained_wav, samples])
+            logger.info(
+                f"The connection remain the audio samples: {self.remained_wav.shape}")
+
+            # pcm16 -> pcm 32
+            samples = pcm2float(self.remained_wav)
+            # read audio
+            speech_segment = SpeechSegment.from_pcm(
+                samples, self.sample_rate, transcript=" ")
+            # audio augment
+            self.collate_fn_test.augmentation.transform_audio(speech_segment)
+
+            # extract speech feature
+            spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
+                speech_segment, self.collate_fn_test.keep_transcription_text)
+            # CMVN spectrum
+            if self.collate_fn_test._normalizer:
+                spectrum = self.collate_fn_test._normalizer.apply(spectrum)
+
+            # spectrum augment
+            audio = self.collate_fn_test.augmentation.transform_feature(spectrum)
+
+            audio_len = audio.shape[0]
+            audio = paddle.to_tensor(audio, dtype='float32')
+            # audio_len = paddle.to_tensor(audio_len)
+            audio = paddle.unsqueeze(audio, axis=0)
+
+            if self.cached_feat is None:
+                self.cached_feat = audio
+            else:
+                assert (len(audio.shape) == 3)
+                assert (len(self.cached_feat.shape) == 3)
+                self.cached_feat = paddle.concat(
+                    [self.cached_feat, audio], axis=1)
+
+            # set the feat device
+            if self.device is None:
+                self.device = self.cached_feat.place
+
+            self.num_frames += audio_len
+            self.remained_wav = self.remained_wav[self.n_shift * audio_len:]
+
+            logger.info(
+                f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}")
+            logger.info(
+                f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}")
         elif "conformer2online" in self.model_type:
             logger.info("Online ASR extract the feat")
             samples = np.frombuffer(samples, dtype=np.int16)
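extract_feat() converts the accumulated int16 buffer to float through pcm2float, a server utility that is imported elsewhere and not shown in this diff. A minimal stand-in, assuming plain int16-to-float32 scaling, would be:

import numpy as np

def pcm2float_sketch(data: np.ndarray) -> np.ndarray:
    # Map int16 PCM [-32768, 32767] into float32 values in [-1.0, 1.0).
    return data.astype(np.float32) / -np.iinfo(np.int16).min

samples = np.array([0, 16384, -32768], dtype=np.int16)
print(pcm2float_sketch(samples))  # [ 0.   0.5 -1. ]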
@@ -179,24 +262,81 @@ class PaddleASRConnectionHanddler:
             # logger.info(f"accumulate samples: {self.num_samples}")

     def reset(self):
-        self.subsampling_cache = None
-        self.elayers_output_cache = None
-        self.conformer_cnn_cache = None
-        self.encoder_out = None
-        self.cached_feat = None
-        self.remained_wav = None
-        self.offset = 0
-        self.num_samples = 0
-        self.device = None
-        self.hyps = []
-        self.num_frames = 0
-        self.chunk_num = 0
-        self.global_frame_offset = 0
-        self.result_transcripts = ['']
+        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+            # for deepspeech2
+            self.chunk_state_h_box = copy.deepcopy(
+                self.asr_engine.executor.chunk_state_h_box)
+            self.chunk_state_c_box = copy.deepcopy(
+                self.asr_engine.executor.chunk_state_c_box)
+            self.decoder.reset_decoder(batch_size=1)
+        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
+            # for conformer online
+            self.subsampling_cache = None
+            self.elayers_output_cache = None
+            self.conformer_cnn_cache = None
+            self.encoder_out = None
+
+        self.cached_feat = None
+        self.remained_wav = None
+        self.offset = 0
+        self.num_samples = 0
+        self.device = None
+        self.hyps = []
+        self.num_frames = 0
+        self.chunk_num = 0
+        self.global_frame_offset = 0
+        self.result_transcripts = ['']

     def decode(self, is_finished=False):
         if "deepspeech2online" in self.model_type:
-            pass
+            # x_chunk is the feature data
+            decoding_chunk_size = 1  # decoding_chunk_size=1 in deepspeech2 model
+            context = 7  # context=7 in deepspeech2 model
+            subsampling = 4  # subsampling=4 in deepspeech2 model
+
+            stride = subsampling * decoding_chunk_size
+            cached_feature_num = context - subsampling
+            # decoding window for model
+            decoding_window = (decoding_chunk_size - 1) * subsampling + context
+
+            if self.cached_feat is None:
+                logger.info("no audio feat, please input more pcm data")
+                return
+
+            num_frames = self.cached_feat.shape[1]
+            logger.info(
+                f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
+            )
+
+            # the cached feat must be larger decoding_window
+            if num_frames < decoding_window and not is_finished:
+                logger.info(
+                    f"frame feat num is less than {decoding_window}, please input more pcm data"
+                )
+                return None, None
+
+            # if is_finished=True, we need at least context frames
+            if num_frames < context:
+                logger.info(
+                    "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
+                )
+                return None, None
+
+            logger.info("start to do model forward")
+            # num_frames - context + 1 ensure that current frame can get context window
+            if is_finished:
+                # if get the finished chunk, we need process the last context
+                left_frames = context
+            else:
+                # we only process decoding_window frames for one chunk
+                left_frames = decoding_window
+
+            for cur in range(0, num_frames - left_frames + 1, stride):
+                end = min(cur + decoding_window, num_frames)
+                # extract the audio
+                x_chunk = self.cached_feat[:, cur:end, :].numpy()
+                x_chunk_lens = np.array([x_chunk.shape[1]])
+                trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
+
+            self.result_transcripts = [trans_best]
+            self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
+            # return trans_best[0]
         elif "conformer" in self.model_type or "transformer" in self.model_type:
             try:
                 logger.info(
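The constants hard-coded in the DeepSpeech2 branch of decode() define a sliding window over the cached feature frames. A worked example with those values (decoding_chunk_size=1, context=7, subsampling=4) and an assumed 15 cached frames:

decoding_chunk_size = 1
context = 7
subsampling = 4
stride = subsampling * decoding_chunk_size                           # 4 frames per step
cached_feature_num = context - subsampling                           # 3 overlap frames kept
decoding_window = (decoding_chunk_size - 1) * subsampling + context  # 7 frames per window

num_frames, is_finished = 15, False
left_frames = context if is_finished else decoding_window
for cur in range(0, num_frames - left_frames + 1, stride):
    end = min(cur + decoding_window, num_frames)
    print(cur, end)  # windows [0, 7), [4, 11), [8, 15)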
@@ -210,6 +350,48 @@ class PaddleASRConnectionHanddler:
         else:
             raise Exception("invalid model name")

+    def decode_one_chunk(self, x_chunk, x_chunk_lens):
+        logger.info("start to decoce one chunk with deepspeech2 model")
+        input_names = self.am_predictor.get_input_names()
+        audio_handle = self.am_predictor.get_input_handle(input_names[0])
+        audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
+        h_box_handle = self.am_predictor.get_input_handle(input_names[2])
+        c_box_handle = self.am_predictor.get_input_handle(input_names[3])
+
+        audio_handle.reshape(x_chunk.shape)
+        audio_handle.copy_from_cpu(x_chunk)
+
+        audio_len_handle.reshape(x_chunk_lens.shape)
+        audio_len_handle.copy_from_cpu(x_chunk_lens)
+
+        h_box_handle.reshape(self.chunk_state_h_box.shape)
+        h_box_handle.copy_from_cpu(self.chunk_state_h_box)
+
+        c_box_handle.reshape(self.chunk_state_c_box.shape)
+        c_box_handle.copy_from_cpu(self.chunk_state_c_box)
+
+        output_names = self.am_predictor.get_output_names()
+        output_handle = self.am_predictor.get_output_handle(output_names[0])
+        output_lens_handle = self.am_predictor.get_output_handle(output_names[1])
+        output_state_h_handle = self.am_predictor.get_output_handle(output_names[2])
+        output_state_c_handle = self.am_predictor.get_output_handle(output_names[3])
+
+        self.am_predictor.run()
+
+        output_chunk_probs = output_handle.copy_to_cpu()
+        output_chunk_lens = output_lens_handle.copy_to_cpu()
+        self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
+        self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
+
+        self.decoder.next(output_chunk_probs, output_chunk_lens)
+        trans_best, trans_beam = self.decoder.decode()
+        logger.info(f"decode one one best result: {trans_best[0]}")
+        return trans_best[0]
+
     def advance_decoding(self, is_finished=False):
         logger.info("start to decode with advanced_decoding method")
         cfg = self.ctc_decode_config
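decode_one_chunk() overwrites chunk_state_h_box/chunk_state_c_box with the model's output states, which is why reset() deep-copies them from the executor: each connection must own its RNN state, which is the "multi session" behavior in the commit title. An illustration with a stand-in numpy state (the shape is made up):

import copy
import numpy as np

engine_init_h = np.zeros((5, 1, 1024), dtype=np.float32)  # stand-in initial h state

session_a_h = copy.deepcopy(engine_init_h)  # what handler A's reset() does
session_b_h = copy.deepcopy(engine_init_h)  # what handler B's reset() does

session_a_h += 1.0  # session A advances its state after decoding a chunk
assert not np.shares_memory(session_a_h, session_b_h)
assert session_b_h.sum() == 0.0  # session B's state is unaffected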
@@ -240,6 +422,7 @@ class PaddleASRConnectionHanddler:
                 )
                 return None, None

+            # if is_finished=True, we need at least context frames
             if num_frames < context:
                 logger.info(
                     "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
@@ -315,6 +498,9 @@ class PaddleASRConnectionHanddler:
             return ''

     def rescoring(self):
+        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+            return
+
         logger.info("rescoring the final result")
         if "attention_rescoring" != self.ctc_decode_config.decoding_method:
             return