Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
380afbbc
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
206
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
380afbbc
编写于
4月 19, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ds2 model multi session, test=doc
上级
5acb0b52
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
263 addition
and
56 deletion
+263
-56
paddlespeech/server/conf/ws_application.yaml
paddlespeech/server/conf/ws_application.yaml
+13
-37
paddlespeech/server/conf/ws_conformer_application.yaml
paddlespeech/server/conf/ws_conformer_application.yaml
+45
-0
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+205
-19
未找到文件。
paddlespeech/server/conf/ws_application.yaml
浏览文件 @
380afbbc
...
...
@@ -18,44 +18,10 @@ engine_list: ['asr_online']
# ENGINE CONFIG #
#################################################################################
# ################################### ASR #########################################
# ################### speech task: asr; engine_type: online #######################
# asr_online:
# model_type: 'deepspeech2online_aishell'
# am_model: # the pdmodel file of am static model [optional]
# am_params: # the pdiparams file of am static model [optional]
# lang: 'zh'
# sample_rate: 16000
# cfg_path:
# decode_method:
# force_yes: True
# am_predictor_conf:
# device: # set 'gpu:id' or 'cpu'
# switch_ir_optim: True
# glog_info: False # True -> print glog
# summary: True # False -> do not show predictor config
# chunk_buffer_conf:
# frame_duration_ms: 80
# shift_ms: 40
# sample_rate: 16000
# sample_width: 2
# vad_conf:
# aggressiveness: 2
# sample_rate: 16000
# frame_duration_ms: 20
# sample_width: 2
# padding_ms: 200
# padding_ratio: 0.9
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online
:
model_type
:
'
conformer
2online_aishell'
model_type
:
'
deepspeech
2online_aishell'
am_model
:
# the pdmodel file of am static model [optional]
am_params
:
# the pdiparams file of am static model [optional]
lang
:
'
zh'
...
...
@@ -71,9 +37,19 @@ asr_online:
summary
:
True
# False -> do not show predictor config
chunk_buffer_conf
:
frame_duration_ms
:
80
shift_ms
:
40
sample_rate
:
16000
sample_width
:
2
window_n
:
7
# frame
shift_n
:
4
# frame
window_ms
:
2
5
# ms
window_ms
:
2
0
# ms
shift_ms
:
10
# ms
vad_conf
:
aggressiveness
:
2
sample_rate
:
16000
sample_width
:
2
\ No newline at end of file
frame_duration_ms
:
20
sample_width
:
2
padding_ms
:
200
padding_ratio
:
0.9
\ No newline at end of file
paddlespeech/server/conf/ws_conformer_application.yaml
0 → 100644
浏览文件 @
380afbbc
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host
:
0.0.0.0
port
:
8090
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only support online engine type.
protocol
:
'
websocket'
engine_list
:
[
'
asr_online'
]
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online
:
model_type
:
'
conformer2online_aishell'
am_model
:
# the pdmodel file of am static model [optional]
am_params
:
# the pdiparams file of am static model [optional]
lang
:
'
zh'
sample_rate
:
16000
cfg_path
:
decode_method
:
force_yes
:
True
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
switch_ir_optim
:
True
glog_info
:
False
# True -> print glog
summary
:
True
# False -> do not show predictor config
chunk_buffer_conf
:
window_n
:
7
# frame
shift_n
:
4
# frame
window_ms
:
25
# ms
shift_ms
:
10
# ms
sample_rate
:
16000
sample_width
:
2
\ No newline at end of file
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
380afbbc
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
import
os
from
typing
import
Optional
import
copy
import
numpy
as
np
import
paddle
from
numpy
import
float32
...
...
@@ -93,7 +93,7 @@ class PaddleASRConnectionHanddler:
)
self
.
config
=
asr_engine
.
config
self
.
model_config
=
asr_engine
.
executor
.
config
self
.
model
=
asr_engine
.
executor
.
model
#
self.model = asr_engine.executor.model
self
.
asr_engine
=
asr_engine
self
.
init
()
...
...
@@ -102,7 +102,32 @@ class PaddleASRConnectionHanddler:
def
init
(
self
):
self
.
model_type
=
self
.
asr_engine
.
executor
.
model_type
if
"deepspeech2online"
in
self
.
model_type
or
"deepspeech2offline"
in
self
.
model_type
:
pass
from
paddlespeech.s2t.io.collator
import
SpeechCollator
self
.
sample_rate
=
self
.
asr_engine
.
executor
.
sample_rate
self
.
am_predictor
=
self
.
asr_engine
.
executor
.
am_predictor
self
.
text_feature
=
self
.
asr_engine
.
executor
.
text_feature
self
.
collate_fn_test
=
SpeechCollator
.
from_config
(
self
.
model_config
)
self
.
decoder
=
CTCDecoder
(
odim
=
self
.
model_config
.
output_dim
,
# <blank> is in vocab
enc_n_units
=
self
.
model_config
.
rnn_layer_size
*
2
,
blank_id
=
self
.
model_config
.
blank_id
,
dropout_rate
=
0.0
,
reduction
=
True
,
# sum
batch_average
=
True
,
# sum / batch_size
grad_norm_type
=
self
.
model_config
.
get
(
'ctc_grad_norm_type'
,
None
))
cfg
=
self
.
model_config
.
decode
decode_batch_size
=
1
# for online
self
.
decoder
.
init_decoder
(
decode_batch_size
,
self
.
text_feature
.
vocab_list
,
cfg
.
decoding_method
,
cfg
.
lang_model_path
,
cfg
.
alpha
,
cfg
.
beta
,
cfg
.
beam_size
,
cfg
.
cutoff_prob
,
cfg
.
cutoff_top_n
,
cfg
.
num_proc_bsearch
)
# frame window samples length and frame shift samples length
self
.
win_length
=
int
(
self
.
model_config
.
window_ms
*
self
.
sample_rate
)
self
.
n_shift
=
int
(
self
.
model_config
.
stride_ms
*
self
.
sample_rate
)
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
or
"wenetspeech"
in
self
.
model_type
:
self
.
sample_rate
=
self
.
asr_engine
.
executor
.
sample_rate
...
...
@@ -127,7 +152,65 @@ class PaddleASRConnectionHanddler:
def
extract_feat
(
self
,
samples
):
if
"deepspeech2online"
in
self
.
model_type
:
pass
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
if
self
.
remained_wav
is
None
:
self
.
remained_wav
=
samples
else
:
assert
self
.
remained_wav
.
ndim
==
1
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
logger
.
info
(
f
"The connection remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
# pcm16 -> pcm 32
samples
=
pcm2float
(
self
.
remained_wav
)
# read audio
speech_segment
=
SpeechSegment
.
from_pcm
(
samples
,
self
.
sample_rate
,
transcript
=
" "
)
# audio augment
self
.
collate_fn_test
.
augmentation
.
transform_audio
(
speech_segment
)
# extract speech feature
spectrum
,
transcript_part
=
self
.
collate_fn_test
.
_speech_featurizer
.
featurize
(
speech_segment
,
self
.
collate_fn_test
.
keep_transcription_text
)
# CMVN spectrum
if
self
.
collate_fn_test
.
_normalizer
:
spectrum
=
self
.
collate_fn_test
.
_normalizer
.
apply
(
spectrum
)
# spectrum augment
audio
=
self
.
collate_fn_test
.
augmentation
.
transform_feature
(
spectrum
)
audio_len
=
audio
.
shape
[
0
]
audio
=
paddle
.
to_tensor
(
audio
,
dtype
=
'float32'
)
# audio_len = paddle.to_tensor(audio_len)
audio
=
paddle
.
unsqueeze
(
audio
,
axis
=
0
)
if
self
.
cached_feat
is
None
:
self
.
cached_feat
=
audio
else
:
assert
(
len
(
audio
.
shape
)
==
3
)
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
self
.
cached_feat
=
paddle
.
concat
(
[
self
.
cached_feat
,
audio
],
axis
=
1
)
# set the feat device
if
self
.
device
is
None
:
self
.
device
=
self
.
cached_feat
.
place
self
.
num_frames
+=
audio_len
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
audio_len
:]
logger
.
info
(
f
"process the audio feature success, the connection feat shape:
{
self
.
cached_feat
.
shape
}
"
)
logger
.
info
(
f
"After extract feat, the connection remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
elif
"conformer2online"
in
self
.
model_type
:
logger
.
info
(
"Online ASR extract the feat"
)
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
...
...
@@ -179,24 +262,81 @@ class PaddleASRConnectionHanddler:
# logger.info(f"accumulate samples: {self.num_samples}")
def
reset
(
self
):
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
encoder_out
=
None
self
.
cached_feat
=
None
self
.
remained_wav
=
None
self
.
offset
=
0
self
.
num_samples
=
0
self
.
device
=
None
self
.
hyps
=
[]
self
.
num_frames
=
0
self
.
chunk_num
=
0
self
.
global_frame_offset
=
0
self
.
result_transcripts
=
[
''
]
if
"deepspeech2online"
in
self
.
model_type
or
"deepspeech2offline"
in
self
.
model_type
:
# for deepspeech2
self
.
chunk_state_h_box
=
copy
.
deepcopy
(
self
.
asr_engine
.
executor
.
chunk_state_h_box
)
self
.
chunk_state_c_box
=
copy
.
deepcopy
(
self
.
asr_engine
.
executor
.
chunk_state_c_box
)
self
.
decoder
.
reset_decoder
(
batch_size
=
1
)
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
or
"wenetspeech"
in
self
.
model_type
:
# for conformer online
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
encoder_out
=
None
self
.
cached_feat
=
None
self
.
remained_wav
=
None
self
.
offset
=
0
self
.
num_samples
=
0
self
.
device
=
None
self
.
hyps
=
[]
self
.
num_frames
=
0
self
.
chunk_num
=
0
self
.
global_frame_offset
=
0
self
.
result_transcripts
=
[
''
]
def
decode
(
self
,
is_finished
=
False
):
if
"deepspeech2online"
in
self
.
model_type
:
pass
# x_chunk 是特征数据
decoding_chunk_size
=
1
# decoding_chunk_size=1 in deepspeech2 model
context
=
7
# context=7 in deepspeech2 model
subsampling
=
4
# subsampling=4 in deepspeech2 model
stride
=
subsampling
*
decoding_chunk_size
cached_feature_num
=
context
-
subsampling
# decoding window for model
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
if
self
.
cached_feat
is
None
:
logger
.
info
(
"no audio feat, please input more pcm data"
)
return
num_frames
=
self
.
cached_feat
.
shape
[
1
]
logger
.
info
(
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
)
# the cached feat must be larger decoding_window
if
num_frames
<
decoding_window
and
not
is_finished
:
logger
.
info
(
f
"frame feat num is less than
{
decoding_window
}
, please input more pcm data"
)
return
None
,
None
# if is_finished=True, we need at least context frames
if
num_frames
<
context
:
logger
.
info
(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return
None
,
None
logger
.
info
(
"start to do model forward"
)
# num_frames - context + 1 ensure that current frame can get context window
if
is_finished
:
# if get the finished chunk, we need process the last context
left_frames
=
context
else
:
# we only process decoding_window frames for one chunk
left_frames
=
decoding_window
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
end
=
min
(
cur
+
decoding_window
,
num_frames
)
# extract the audio
x_chunk
=
self
.
cached_feat
[:,
cur
:
end
,
:].
numpy
()
x_chunk_lens
=
np
.
array
([
x_chunk
.
shape
[
1
]])
trans_best
=
self
.
decode_one_chunk
(
x_chunk
,
x_chunk_lens
)
self
.
result_transcripts
=
[
trans_best
]
self
.
cached_feat
=
self
.
cached_feat
[:,
end
-
cached_feature_num
:,
:]
# return trans_best[0]
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
try
:
logger
.
info
(
...
...
@@ -210,6 +350,48 @@ class PaddleASRConnectionHanddler:
else
:
raise
Exception
(
"invalid model name"
)
def
decode_one_chunk
(
self
,
x_chunk
,
x_chunk_lens
):
logger
.
info
(
"start to decoce one chunk with deepspeech2 model"
)
input_names
=
self
.
am_predictor
.
get_input_names
()
audio_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
0
])
audio_len_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
1
])
h_box_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
2
])
c_box_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
3
])
audio_handle
.
reshape
(
x_chunk
.
shape
)
audio_handle
.
copy_from_cpu
(
x_chunk
)
audio_len_handle
.
reshape
(
x_chunk_lens
.
shape
)
audio_len_handle
.
copy_from_cpu
(
x_chunk_lens
)
h_box_handle
.
reshape
(
self
.
chunk_state_h_box
.
shape
)
h_box_handle
.
copy_from_cpu
(
self
.
chunk_state_h_box
)
c_box_handle
.
reshape
(
self
.
chunk_state_c_box
.
shape
)
c_box_handle
.
copy_from_cpu
(
self
.
chunk_state_c_box
)
output_names
=
self
.
am_predictor
.
get_output_names
()
output_handle
=
self
.
am_predictor
.
get_output_handle
(
output_names
[
0
])
output_lens_handle
=
self
.
am_predictor
.
get_output_handle
(
output_names
[
1
])
output_state_h_handle
=
self
.
am_predictor
.
get_output_handle
(
output_names
[
2
])
output_state_c_handle
=
self
.
am_predictor
.
get_output_handle
(
output_names
[
3
])
self
.
am_predictor
.
run
()
output_chunk_probs
=
output_handle
.
copy_to_cpu
()
output_chunk_lens
=
output_lens_handle
.
copy_to_cpu
()
self
.
chunk_state_h_box
=
output_state_h_handle
.
copy_to_cpu
()
self
.
chunk_state_c_box
=
output_state_c_handle
.
copy_to_cpu
()
self
.
decoder
.
next
(
output_chunk_probs
,
output_chunk_lens
)
trans_best
,
trans_beam
=
self
.
decoder
.
decode
()
logger
.
info
(
f
"decode one one best result:
{
trans_best
[
0
]
}
"
)
return
trans_best
[
0
]
def
advance_decoding
(
self
,
is_finished
=
False
):
logger
.
info
(
"start to decode with advanced_decoding method"
)
cfg
=
self
.
ctc_decode_config
...
...
@@ -240,6 +422,7 @@ class PaddleASRConnectionHanddler:
)
return
None
,
None
# if is_finished=True, we need at least context frames
if
num_frames
<
context
:
logger
.
info
(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
...
...
@@ -315,6 +498,9 @@ class PaddleASRConnectionHanddler:
return
''
def
rescoring
(
self
):
if
"deepspeech2online"
in
self
.
model_type
or
"deepspeech2offline"
in
self
.
model_type
:
return
logger
.
info
(
"rescoring the final result"
)
if
"attention_rescoring"
!=
self
.
ctc_decode_config
.
decoding_method
:
return
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录