PaddlePaddle / DeepSpeech

Commit 94327238: refactor asr online
Author: Hui Zhang
Committed: May 23, 2022
Parent: e6ddb0cc

Showing 1 changed file with 200 additions and 381 deletions:
paddlespeech/server/engine/asr/online/asr_engine.py (+200, -381)
@@ -38,7 +38,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
 from paddlespeech.server.utils.audio_process import pcm2float
 from paddlespeech.server.utils.paddle_predictor import init_predictor

-__all__ = ['ASREngine']
+__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']

 # ASR server connection process class
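The export change above is the heart of the refactor: connection-scoped streaming state now lives in PaddleASRConnectionHanddler instead of the shared executor. A minimal usage sketch follows; the server wiring and callback names are hypothetical, only the exported classes and the handler methods shown further down in this diff are real:

from paddlespeech.server.engine.asr.online.asr_engine import (
    ASREngine, PaddleASRConnectionHanddler)

engine = ASREngine()
# engine.init(server_config) elided; see ASREngine.init below

def on_connection_open(engine):
    # one handler per client, so per-stream caches (remained_wav,
    # cached_feat, decoder state) never leak across connections
    return PaddleASRConnectionHanddler(engine)

def on_audio_packet(handler, pcm_bytes):
    handler.extract_feat(pcm_bytes)    # buffer pcm, compute features
    handler.decode(is_finished=False)  # advance partial decoding
    return handler.get_result()        # current one-best transcript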
@@ -67,7 +67,7 @@ class PaddleASRConnectionHanddler:
         # tokens to text
         self.text_feature = self.asr_engine.executor.text_feature

-        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+        if "deepspeech2" in self.model_type:
             from paddlespeech.s2t.io.collator import SpeechCollator
             self.am_predictor = self.asr_engine.executor.am_predictor
@@ -89,8 +89,8 @@ class PaddleASRConnectionHanddler:
                 cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                 cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                 cfg.num_proc_bsearch)

-            # frame window samples length and frame shift samples length
+            # frame window and frame shift, in samples unit
             self.win_length = int(self.model_config.window_ms / 1000 *
                                   self.sample_rate)
             self.n_shift = int(self.model_config.stride_ms / 1000 *
                                self.sample_rate)
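For concreteness, the sample-unit arithmetic above works out as follows; the window and stride values are illustrative, not read from this model's config:

sample_rate = 16000  # Hz, typical for the 16k models
window_ms = 25       # frame window, ms (assumed)
stride_ms = 10       # frame shift, ms (assumed)

win_length = int(window_ms / 1000 * sample_rate)  # 400 samples
n_shift = int(stride_ms / 1000 * sample_rate)     # 160 samples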
@@ -109,16 +109,15 @@ class PaddleASRConnectionHanddler:
             self.preprocess_args = {"train": False}
             self.preprocessing = Transformation(self.preprocess_conf)

-            # frame window samples length and frame shift samples length
+            # frame window and frame shift, in samples unit
             self.win_length = self.preprocess_conf.process[0]['win_length']
             self.n_shift = self.preprocess_conf.process[0]['n_shift']

         else:
             raise ValueError(f"Not supported: {self.model_type}")

     def extract_feat(self, samples):
         # we compute the elapsed time of first char occuring
         # and we record the start time at the first pcm sample arraving
         # if self.first_char_occur_elapsed is not None:
         #     self.first_char_occur_elapsed = time.time()
         if "deepspeech2online" in self.model_type:
             # self.reamined_wav stores all the samples,
@@ -154,28 +153,28 @@ class PaddleASRConnectionHanddler:
                 spectrum = self.collate_fn_test._normalizer.apply(spectrum)

             # spectrum augment
-            audio = self.collate_fn_test.augmentation.transform_feature(
+            feat = self.collate_fn_test.augmentation.transform_feature(
                 spectrum)

-            audio_len = audio.shape[0]
-            audio = paddle.to_tensor(audio, dtype='float32')
-            # audio_len = paddle.to_tensor(audio_len)
-            audio = paddle.unsqueeze(audio, axis=0)
+            # audio_len is frame num
+            frame_num = feat.shape[0]
+            feat = paddle.to_tensor(feat, dtype='float32')
+            feat = paddle.unsqueeze(feat, axis=0)

             if self.cached_feat is None:
-                self.cached_feat = audio
+                self.cached_feat = feat
             else:
-                assert (len(audio.shape) == 3)
+                assert (len(feat.shape) == 3)
                 assert (len(self.cached_feat.shape) == 3)
                 self.cached_feat = paddle.concat(
-                    [self.cached_feat, audio], axis=1)
+                    [self.cached_feat, feat], axis=1)

             # set the feat device
             if self.device is None:
                 self.device = self.cached_feat.place

-            self.num_frames += audio_len
-            self.remained_wav = self.remained_wav[self.n_shift * audio_len:]
+            self.num_frames += frame_num
+            self.remained_wav = self.remained_wav[self.n_shift * frame_num:]

             logger.info(
                 f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
@@ -183,25 +182,28 @@ class PaddleASRConnectionHanddler:
             logger.info(
                 f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
             )
         elif "conformer_online" in self.model_type:
             logger.info("Online ASR extract the feat")
             samples = np.frombuffer(samples, dtype=np.int16)
             assert samples.ndim == 1

-            logger.info(f"This package receive {samples.shape[0]} pcm data")
             self.num_samples += samples.shape[0]
+            logger.info(
+                f"This package receive {samples.shape[0]} pcm data. Global samples: {self.num_samples}"
+            )

             # self.reamined_wav stores all the samples,
             # include the original remained_wav and this package samples
             if self.remained_wav is None:
                 self.remained_wav = samples
             else:
-                assert self.remained_wav.ndim == 1
+                assert self.remained_wav.ndim == 1  # (T,)
                 self.remained_wav = np.concatenate([self.remained_wav, samples])
-            logger.info(
-                f"The connection remain the audio samples: {self.remained_wav.shape}"
-            )
+            logger.info(
+                f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
+            )

             if len(self.remained_wav) < self.win_length:
                 # samples not enough for feature window
                 return 0

             # fbank
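The buffering logic of this hunk, pulled out into a standalone sketch; it assumes 16-bit little-endian pcm arriving in arbitrary-sized network packets, and all names are local to the sketch:

import numpy as np

remained_wav = None
win_length = 400  # samples per feature window (assumed, see above)

def on_packet(pcm_bytes: bytes) -> bool:
    """Buffer one pcm packet; return True once a feature window is ready."""
    global remained_wav
    samples = np.frombuffer(pcm_bytes, dtype=np.int16)
    assert samples.ndim == 1  # (T,)
    if remained_wav is None:
        remained_wav = samples
    else:
        remained_wav = np.concatenate([remained_wav, samples])
    # samples not enough for one feature window yet?
    return len(remained_wav) >= win_length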
@@ -209,11 +211,13 @@ class PaddleASRConnectionHanddler:
                 **self.preprocess_args)
             x_chunk = paddle.to_tensor(
                 x_chunk, dtype="float32").unsqueeze(axis=0)

             # feature cache
             if self.cached_feat is None:
                 self.cached_feat = x_chunk
             else:
-                assert (len(x_chunk.shape) == 3)
-                assert (len(self.cached_feat.shape) == 3)
+                assert (len(x_chunk.shape) == 3)  # (B,T,D)
+                assert (len(self.cached_feat.shape) == 3)  # (B,T,D)
                 self.cached_feat = paddle.concat(
                     [self.cached_feat, x_chunk], axis=1)
@@ -221,20 +225,30 @@ class PaddleASRConnectionHanddler:
             if self.device is None:
                 self.device = self.cached_feat.place

+            # cur frame step
             num_frames = x_chunk.shape[1]
+            # global frame step
             self.num_frames += num_frames
+            # update remained wav
             self.remained_wav = self.remained_wav[self.n_shift * num_frames:]

             logger.info(
-                f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
+                f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
             )
             logger.info(
-                f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
+                f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
             )
-            # logger.info(f"accumulate samples: {self.num_samples}")
+            logger.info(f"global samples: {self.num_samples}")
+            logger.info(f"global frames: {self.num_frames}")
         else:
             raise ValueError(f"not supported: {self.model_type}")

     def reset(self):
-        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+        if "deepspeech2" in self.model_type:
             # for deepspeech2
             self.chunk_state_h_box = copy.deepcopy(
                 self.asr_engine.executor.chunk_state_h_box)
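The trim `self.remained_wav = self.remained_wav[self.n_shift * num_frames:]` keeps exactly the window overhang the next packet still needs. A numeric check with the illustrative values from earlier:

win_length, n_shift = 400, 160   # assumed, as above
buffered = 1000                  # samples currently in remained_wav

num_frames = 1 + (buffered - win_length) // n_shift  # 4 full frames
consumed = n_shift * num_frames                      # 640 samples dropped
left = buffered - consumed                           # 360 samples kept
assert left >= win_length - n_shift  # overlap needed by the next frame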
@@ -242,35 +256,63 @@ class PaddleASRConnectionHanddler:
             self.chunk_state_c_box = copy.deepcopy(
                 self.asr_engine.executor.chunk_state_c_box)
             self.decoder.reset_decoder(batch_size=1)

-        # for conformer online
-        self.subsampling_cache = None
-        self.elayers_output_cache = None
-        self.conformer_cnn_cache = None
-        self.encoder_out = None
-        self.cached_feat = None
-        self.remained_wav = None
-        self.offset = 0
-        self.num_samples = 0
-        self.device = None
-        self.hyps = []
-        self.num_frames = 0
-        self.chunk_num = 0
-        self.global_frame_offset = 0
-        self.result_transcripts = ['']
-        self.word_time_stamp = []
-        self.time_stamp = []
-        self.first_char_occur_elapsed = None
+        self.device = None
+
+        ## common
+        # global sample and frame step
+        self.num_samples = 0
+        self.num_frames = 0
+
+        # cache for audio and feat
+        self.remained_wav = None
+        self.cached_feat = None
+
+        # partial/ending decoding results
+        self.result_transcripts = ['']
+
+        ## conformer
+        # cache for conformer online
+        self.subsampling_cache = None
+        self.elayers_output_cache = None
+        self.conformer_cnn_cache = None
+        self.encoder_out = None
+
+        # conformer decoding state
+        self.chunk_num = 0  # globa decoding chunk num
+        self.offset = 0  # global offset in decoding frame unit
+        self.hyps = []
+        self.global_frame_offset = 0
+        # token timestamp result
+        self.word_time_stamp = []
+        # one best timestamp viterbi prob is large.
+        self.time_stamp = []
+        self.first_char_occur_elapsed = None

     def decode(self, is_finished=False):
+        """advance decoding
+        Args:
+            is_finished (bool, optional): Is last frame or not. Defaults to False.
+        Raises:
+            Exception: when not support model.
+        Returns:
+            None: nothing
+        """
         if "deepspeech2online" in self.model_type:
-            # x_chunk is the feature data
-            decoding_chunk_size = 1  # decoding_chunk_size=1 in deepspeech2 model
-            context = 7  # context=7 in deepspeech2 model
-            subsampling = 4  # subsampling=4 in deepspeech2 model
-            stride = subsampling * decoding_chunk_size
+            decoding_chunk_size = 1  # decoding chunk size = 1. int decoding frame unit
+            context = 7  # context=7, in audio frame unit
+            subsampling = 4  # subsampling=4, in audio frame unit
             cached_feature_num = context - subsampling
-            # decoding window for model
+            # decoding window for model, in audio frame unit
             decoding_window = (decoding_chunk_size - 1) * subsampling + context
+            # decoding stride for model, in audio frame unit
+            stride = subsampling * decoding_chunk_size

             if self.cached_feat is None:
                 logger.info("no audio feat, please input more pcm data")
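Plugging in the deepspeech2 constants makes the renamed comments concrete; this is pure arithmetic over values already in the hunk:

decoding_chunk_size = 1          # in decoding frame unit
context = 7                      # in audio frame unit
subsampling = 4                  # audio frames per decoding frame

cached_feature_num = context - subsampling                           # 3
decoding_window = (decoding_chunk_size - 1) * subsampling + context  # 7
stride = subsampling * decoding_chunk_size                           # 4
# each forward consumes a 7-frame window, advances by 4 frames, and
# re-feeds the trailing 3 frames as context for the next chunk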
@@ -280,6 +322,7 @@ class PaddleASRConnectionHanddler:
             logger.info(
                 f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
             )
+            # the cached feat must be larger decoding_window
             if num_frames < decoding_window and not is_finished:
                 logger.info(
@@ -293,6 +336,7 @@ class PaddleASRConnectionHanddler:
                     "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
                 )
                 return None, None

             logger.info("start to do model forward")
+            # num_frames - context + 1 ensure that current frame can get context window
             if is_finished:
@@ -302,6 +346,7 @@ class PaddleASRConnectionHanddler:
                 # we only process decoding_window frames for one chunk
                 left_frames = decoding_window

+            end = None
             for cur in range(0, num_frames - left_frames + 1, stride):
                 end = min(cur + decoding_window, num_frames)
                 # extract the audio
@@ -311,7 +356,9 @@ class PaddleASRConnectionHanddler:
                 self.result_transcripts = [trans_best]

+            # update feat cache
             self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]

+            # return trans_best[0]
         elif "conformer" in self.model_type or "transformer" in self.model_type:
             try:
@@ -326,9 +373,19 @@ class PaddleASRConnectionHanddler:
         else:
             raise Exception("invalid model name")

     @paddle.no_grad()
     def decode_one_chunk(self, x_chunk, x_chunk_lens):
-        logger.info("start to decoce one chunk with deepspeech2 model")
+        """forward one chunk frames
+        Args:
+            x_chunk (np.ndarray): (B,T,D), audio frames.
+            x_chunk_lens ([type]): (B,), audio frame lens
+        Returns:
+            logprob: poster probability.
+        """
+        logger.info("start to decoce one chunk for deepspeech2")
         input_names = self.am_predictor.get_input_names()
         audio_handle = self.am_predictor.get_input_handle(input_names[0])
         audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
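For readers unfamiliar with the paddle-inference handle API used here, this is the input/run/output cycle reduced to a skeleton; predictor construction is elided, and only calls that appear in this diff are used:

import numpy as np

def forward_one_chunk(am_predictor, x_chunk, x_chunk_lens):
    # bind inputs by name, then copy host buffers in
    input_names = am_predictor.get_input_names()
    audio = am_predictor.get_input_handle(input_names[0])
    audio_len = am_predictor.get_input_handle(input_names[1])
    audio.reshape(x_chunk.shape)
    audio.copy_from_cpu(x_chunk)
    audio_len.reshape(x_chunk_lens.shape)
    audio_len.copy_from_cpu(x_chunk_lens)

    am_predictor.run()

    # fetch outputs back to host memory
    output_names = am_predictor.get_output_names()
    probs = am_predictor.get_output_handle(output_names[0]).copy_to_cpu()
    lens = am_predictor.get_output_handle(output_names[1]).copy_to_cpu()
    return probs, lens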
@@ -365,24 +422,31 @@ class PaddleASRConnectionHanddler:
         self.decoder.next(output_chunk_probs, output_chunk_lens)
         trans_best, trans_beam = self.decoder.decode()
-        logger.info(f"decode one best result: {trans_best[0]}")
+        logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
         return trans_best[0]

     @paddle.no_grad()
     def advance_decoding(self, is_finished=False):
-        logger.info("start to decode with advanced_decoding method")
+        logger.info(
+            "Conformer/Transformer: start to decode with advanced_decoding method"
+        )
         cfg = self.ctc_decode_config

+        # cur chunk size, in decoding frame unit
         decoding_chunk_size = cfg.decoding_chunk_size
+        # using num of history chunks
         num_decoding_left_chunks = cfg.num_decoding_left_chunks
         assert decoding_chunk_size > 0

         subsampling = self.model.encoder.embed.subsampling_rate
         context = self.model.encoder.embed.right_context + 1

-        stride = subsampling * decoding_chunk_size
-        cached_feature_num = context - subsampling  # processed chunk feature cached for next chunk
-        # decoding window for model
+        # processed chunk feature cached for next chunk
+        cached_feature_num = context - subsampling
+        # decoding stride, in audio frame unit
+        stride = subsampling * decoding_chunk_size
+        # decoding window, in audio frame unit
         decoding_window = (decoding_chunk_size - 1) * subsampling + context

         if self.cached_feat is None:
             logger.info("no audio feat, please input more pcm data")
             return
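This is the same window arithmetic as in the deepspeech2 branch, but with the constants read from the conformer model. With a typical subsampling rate of 4, a right context of 6, and a decoding chunk size of 16 (illustrative values, not read from this diff):

subsampling = 4                  # model.encoder.embed.subsampling_rate
context = 6 + 1                  # right_context + 1
decoding_chunk_size = 16         # cfg.decoding_chunk_size

cached_feature_num = context - subsampling                           # 3
stride = subsampling * decoding_chunk_size                           # 64
decoding_window = (decoding_chunk_size - 1) * subsampling + context  # 67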
@@ -407,6 +471,7 @@ class PaddleASRConnectionHanddler:
             return None, None

         logger.info("start to do model forward")
+        # hist of chunks, in deocding frame unit
         required_cache_size = decoding_chunk_size * num_decoding_left_chunks
         outputs = []
@@ -423,8 +488,11 @@ class PaddleASRConnectionHanddler:
         for cur in range(0, num_frames - left_frames + 1, stride):
             end = min(cur + decoding_window, num_frames)

+            # global chunk_num
             self.chunk_num += 1
+            # cur chunk
             chunk_xs = self.cached_feat[:, cur:end, :]
+            # forward chunk
             (y, self.subsampling_cache, self.elayers_output_cache,
              self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
                  chunk_xs, self.offset, required_cache_size,
@@ -432,7 +500,7 @@ class PaddleASRConnectionHanddler:
                  self.conformer_cnn_cache)
             outputs.append(y)

-            # update the offset
+            # update the global offset, in decoding frame unit
             self.offset += y.shape[1]

         ys = paddle.cat(outputs, 1)
@@ -445,12 +513,15 @@ class PaddleASRConnectionHanddler:
         ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
         ctc_probs = ctc_probs.squeeze(0)

+        # advance decoding
         self.searcher.search(ctc_probs, self.cached_feat.place)
+        # get one best hyps
         self.hyps = self.searcher.get_one_best_hyps()

         assert self.cached_feat.shape[0] == 1
         assert end >= cached_feature_num

+        # advance cache of feat
         self.cached_feat = self.cached_feat[0, end -
                                             cached_feature_num:, :].unsqueeze(0)
         assert len(
@@ -462,50 +533,79 @@ class PaddleASRConnectionHanddler:
         )

     def update_result(self):
+        """Conformer/Transformer hyps to result.
+        """
         logger.info("update the final result")
         hyps = self.hyps
+
+        # output results and tokenids
         self.result_transcripts = [
             self.text_feature.defeaturize(hyp) for hyp in hyps
         ]
         self.result_tokenids = [hyp for hyp in hyps]

     def get_result(self):
+        """return partial/ending asr result.
+        Returns:
+            str: one best result of partial/ending.
+        """
         if len(self.result_transcripts) > 0:
             return self.result_transcripts[0]
         else:
             return ''

+    def get_word_time_stamp(self):
+        """return token timestamp result.
+        Returns:
+            list: List of ('w':token, 'bg':time, 'ed':time)
+        """
+        return self.word_time_stamp
+
     @paddle.no_grad()
     def rescoring(self):
-        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+        """Second-Pass Decoding, only for conformer and transformer model.
+        """
+        if "deepspeech2" in self.model_type:
             logger.info("deepspeech2 not support rescoring decoding.")
             return

-        logger.info("rescoring the final result")
         if "attention_rescoring" != self.ctc_decode_config.decoding_method:
             logger.info(
                 f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring"
             )
             return

+        logger.info("rescoring the final result")
+
         # last decoding for last audio
         self.searcher.finalize_search()
         # update beam search results
         self.update_result()

         beam_size = self.ctc_decode_config.beam_size
         hyps = self.searcher.get_hyps()
         if hyps is None or len(hyps) == 0:
+            logger.info("No Hyps!")
             return

         # rescore by decoder post probability
+
+        # assert len(hyps) == beam_size
+        # list of Tensor
         hyp_list = []
         for hyp in hyps:
             hyp_content = hyp[0]
             # Prevent the hyp is empty
             if len(hyp_content) == 0:
                 hyp_content = (self.model.ctc.blank_id, )
             hyp_content = paddle.to_tensor(
                 hyp_content, place=self.device, dtype=paddle.long)
             hyp_list.append(hyp_content)
-        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
+        hyps_pad = pad_sequence(
+            hyp_list, batch_first=True, padding_value=self.model.ignore_id)
         hyps_lens = paddle.to_tensor(
             [len(hyp[0]) for hyp in hyps], place=self.device,
             dtype=paddle.long)  # (beam_size,)
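The pad_sequence change only spells the positional arguments out as keywords; the padding behavior is unchanged. A toy run of that padding, assuming paddlespeech's pad_sequence follows the usual batch_first/padding_value semantics and import path:

import paddle
from paddlespeech.s2t.utils.tensor_utils import pad_sequence

ignore_id = -1
blank_id = 0
hyps = [(1, 5, 9), (2, 7), ()]  # toy beam: tuples of token ids
hyp_list = [
    paddle.to_tensor(h if len(h) > 0 else (blank_id, ), dtype=paddle.long)
    for h in hyps  # empty hyp replaced by blank, as in the diff
]
hyps_pad = pad_sequence(hyp_list, batch_first=True, padding_value=ignore_id)
# hyps_pad.shape == [3, 3]; shorter rows padded with -1 to the longest hyp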
@@ -531,10 +631,12 @@ class PaddleASRConnectionHanddler:
             score = 0.0
             for j, w in enumerate(hyp[0]):
                 score += decoder_out[i][j][w]
+            # last decoder output token is `eos`, for laste decoder input token.
             score += decoder_out[i][len(hyp[0])][self.model.eos]
+            # add ctc score (which in ln domain)
             score += hyp[1] * self.ctc_decode_config.ctc_weight
             if score > best_score:
                 best_score = score
                 best_index = i
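Written out, each beam hypothesis scores as: the sum of decoder log-probabilities for its tokens, plus the log-probability of closing with eos, plus ctc_weight times its ctc score (already in the ln domain). A toy version with made-up numbers:

import numpy as np

ctc_weight = 0.5
eos = 3
# decoder_out: (beam, max_len + 1, vocab) log-softmax scores (made up)
decoder_out = np.log(np.full((2, 3, 4), 0.25))
hyps = [((1, 2), -3.0), ((3, ), -1.5)]  # (token_ids, ctc_score) per beam

best_score, best_index = -float('inf'), 0
for i, (tokens, ctc_score) in enumerate(hyps):
    score = sum(decoder_out[i][j][w] for j, w in enumerate(tokens))
    score += decoder_out[i][len(tokens)][eos]  # closing `eos` step
    score += ctc_weight * ctc_score            # ln-domain ctc term
    if score > best_score:
        best_score, best_index = score, i
# here beam 1 wins: fewer decoder steps and a better ctc score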
@@ -542,47 +644,57 @@ class PaddleASRConnectionHanddler:
         # update the one best result
         # hyps stored the beam results and each fields is:
-        logger.info(f"best index: {best_index}")
+        logger.info(f"best hyp index: {best_index}")
         # logger.info(f'best result: {hyps[best_index]}')
         # the field of the hyps is:
+        ## asr results
         # hyps[0][0]: the sentence word-id in the vocab with a tuple
         # hyps[0][1]: the sentence decoding probability with all paths
+        ## timestamp
         # hyps[0][2]: viterbi_blank ending probability
-        # hyps[0][3]: viterbi_non_blank probability
+        # hyps[0][3]: viterbi_non_blank dending probability
         # hyps[0][4]: current_token_prob,
-        # hyps[0][5]: times_viterbi_blank,
-        # hyps[0][6]: times_titerbi_non_blank
+        # hyps[0][5]: times_viterbi_blank ending timestamp,
+        # hyps[0][6]: times_titerbi_non_blank encding timestamp.
         self.hyps = [hyps[best_index][0]]
+        logger.info(f"best hyp ids: {self.hyps}")

         # update the hyps time stamp
         self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[
             best_index][3] else hyps[best_index][6]
+        logger.info(f"time stamp: {self.time_stamp}")

         # update one best result
         self.update_result()

         # update each word start and end time stamp
-        frame_shift_in_ms = self.model.encoder.embed.subsampling_rate * self.n_shift / self.sample_rate
-        logger.info(f"frame shift ms: {frame_shift_in_ms}")
+        # decoding frame to audio frame
+        frame_shift = self.model.encoder.embed.subsampling_rate
+        frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate)
+        logger.info(f"frame shift sec: {frame_shift_in_sec}")

         word_time_stamp = []
         for idx, _ in enumerate(self.time_stamp):
             start = (self.time_stamp[idx - 1] + self.time_stamp[idx]
                      ) / 2.0 if idx > 0 else 0
-            start = start * frame_shift_in_ms
+            start = start * frame_shift_in_sec

             end = (self.time_stamp[idx] + self.time_stamp[idx + 1]
                    ) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset
-            end = end * frame_shift_in_ms
+            end = end * frame_shift_in_sec

             word_time_stamp.append({
                 "w": self.result_transcripts[0][idx],
                 "bg": start,
                 "ed": end
             })
-            # logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}")
+            # logger.info(f"{word_time_stamp[-1]}")

         self.word_time_stamp = word_time_stamp
         logger.info(f"word time stamp: {self.word_time_stamp}")


 class ASRServerExecutor(ASRExecutor):
     def __init__(self):
         super().__init__()
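Worked through with concrete numbers, the timestamp conversion behaves like this. Assumed values: subsampling_rate 4, n_shift 160 samples, 16 kHz audio, so one decoding frame spans 0.04 s:

frame_shift_in_sec = 4 * (160 / 16000)  # 0.04 s per decoding frame
time_stamp = [3, 9, 17]                 # per-token peak frames (toy)
offset = 20                             # global decoding frame offset

bounds = []
for idx in range(len(time_stamp)):
    start = (time_stamp[idx - 1] + time_stamp[idx]) / 2.0 if idx > 0 else 0
    end = ((time_stamp[idx] + time_stamp[idx + 1]) / 2.0
           if idx < len(time_stamp) - 1 else offset)
    bounds.append((start * frame_shift_in_sec, end * frame_shift_in_sec))
# bounds is approximately [(0.0, 0.24), (0.24, 0.52), (0.52, 0.80)]:
# each token's span runs from the midpoint with its left neighbor to
# the midpoint with its right neighbor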
@@ -610,6 +722,7 @@ class ASRServerExecutor(ASRExecutor):
         self.sample_rate = sample_rate
         sample_rate_str = '16k' if sample_rate == 16000 else '8k'
         tag = model_type + '-' + lang + '-' + sample_rate_str
+
         if cfg_path is None or am_model is None or am_params is None:
             logger.info(f"Load the pretrained model, tag = {tag}")
             res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
@@ -628,7 +741,7 @@ class ASRServerExecutor(ASRExecutor):
             self.am_model = os.path.abspath(am_model)
             self.am_params = os.path.abspath(am_params)
             self.res_path = os.path.dirname(
-                os.path.dirname(os.path.abspath(self.cfg_path)))
+            os.path.dirname(os.path.abspath(self.cfg_path)))

         logger.info(self.cfg_path)
         logger.info(self.am_model)
@@ -639,7 +752,7 @@ class ASRServerExecutor(ASRExecutor):
         self.config.merge_from_file(self.cfg_path)

         with UpdateConfig(self.config):
-            if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
+            if "deepspeech2" in model_type:
                 from paddlespeech.s2t.io.collator import SpeechCollator
                 self.vocab = self.config.vocab_filepath
                 self.config.decode.lang_model_path = os.path.join(
@@ -655,6 +768,7 @@ class ASRServerExecutor(ASRExecutor):
                 self.download_lm(
                     lm_url,
                     os.path.dirname(self.config.decode.lang_model_path), lm_md5)
+
             elif "conformer" in model_type or "transformer" in model_type:
                 logger.info("start to create the stream conformer asr engine")
                 if self.config.spm_model_prefix:
@@ -682,7 +796,8 @@ class ASRServerExecutor(ASRExecutor):
                 ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
             else:
                 raise Exception("wrong type")

-        if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
+        if "deepspeech2" in model_type:
+            # AM predictor
             logger.info("ASR engine start to init the am predictor")
             self.am_predictor_conf = am_predictor_conf
@@ -719,6 +834,7 @@ class ASRServerExecutor(ASRExecutor):
             self.chunk_state_c_box = np.zeros(
                 (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                 dtype=float32)
         elif "conformer" in model_type or "transformer" in model_type:
-            model_name = model_type[:model_type.rindex('_')]
+            model_name = model_type[:model_type.rindex(
+                '_')]  # model_type: {model_name}_{dataset}
@@ -737,277 +853,14 @@ class ASRServerExecutor(ASRExecutor):
             # update the ctc decoding
             self.searcher = CTCPrefixBeamSearch(self.config.decode)
             self.transformer_decode_reset()

         return True

-    def reset_decoder_and_chunk(self):
-        """reset decoder and chunk state for an new audio
-        """
-        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
-            self.decoder.reset_decoder(batch_size=1)
-            # init state box, for new audio request
-            self.chunk_state_h_box = np.zeros(
-                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
-                dtype=float32)
-            self.chunk_state_c_box = np.zeros(
-                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
-                dtype=float32)
-        elif "conformer" in self.model_type or "transformer" in self.model_type:
-            self.transformer_decode_reset()
-
-    def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
-        """decode one chunk
-        Args:
-            x_chunk (numpy.array): shape[B, T, D]
-            x_chunk_lens (numpy.array): shape[B]
-            model_type (str): online model type
-        Returns:
-            str: one best result
-        """
-        logger.info("start to decoce chunk by chunk")
-        if "deepspeech2online" in model_type:
-            input_names = self.am_predictor.get_input_names()
-            audio_handle = self.am_predictor.get_input_handle(input_names[0])
-            audio_len_handle = self.am_predictor.get_input_handle(
-                input_names[1])
-            h_box_handle = self.am_predictor.get_input_handle(input_names[2])
-            c_box_handle = self.am_predictor.get_input_handle(input_names[3])
-
-            audio_handle.reshape(x_chunk.shape)
-            audio_handle.copy_from_cpu(x_chunk)
-
-            audio_len_handle.reshape(x_chunk_lens.shape)
-            audio_len_handle.copy_from_cpu(x_chunk_lens)
-
-            h_box_handle.reshape(self.chunk_state_h_box.shape)
-            h_box_handle.copy_from_cpu(self.chunk_state_h_box)
-
-            c_box_handle.reshape(self.chunk_state_c_box.shape)
-            c_box_handle.copy_from_cpu(self.chunk_state_c_box)
-
-            output_names = self.am_predictor.get_output_names()
-            output_handle = self.am_predictor.get_output_handle(
-                output_names[0])
-            output_lens_handle = self.am_predictor.get_output_handle(
-                output_names[1])
-            output_state_h_handle = self.am_predictor.get_output_handle(
-                output_names[2])
-            output_state_c_handle = self.am_predictor.get_output_handle(
-                output_names[3])
-
-            self.am_predictor.run()
-
-            output_chunk_probs = output_handle.copy_to_cpu()
-            output_chunk_lens = output_lens_handle.copy_to_cpu()
-            self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
-            self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
-
-            self.decoder.next(output_chunk_probs, output_chunk_lens)
-            trans_best, trans_beam = self.decoder.decode()
-            logger.info(f"decode one best result: {trans_best[0]}")
-            return trans_best[0]
-
-        elif "conformer" in model_type or "transformer" in model_type:
-            try:
-                logger.info(
-                    f"we will use the transformer like model : {self.model_type}"
-                )
-                self.advanced_decoding(x_chunk, x_chunk_lens)
-                self.update_result()
-
-                return self.result_transcripts[0]
-            except Exception as e:
-                logger.exception(e)
-        else:
-            raise Exception("invalid model name")
-
-    def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
-        logger.info("start to decode with advanced_decoding method")
-        encoder_out, encoder_mask = self.encoder_forward(xs)
-        ctc_probs = self.model.ctc.log_softmax(
-            encoder_out)  # (1, maxlen, vocab_size)
-        ctc_probs = ctc_probs.squeeze(0)
-        self.searcher.search(ctc_probs, xs.place)
-        # update the one best result
-        self.hyps = self.searcher.get_one_best_hyps()
-
-        # now we supprot ctc_prefix_beam_search and attention_rescoring
-        if "attention_rescoring" in self.config.decode.decoding_method:
-            self.rescoring(encoder_out, xs.place)
-
-    def encoder_forward(self, xs):
-        logger.info("get the model out from the feat")
-        cfg = self.config.decode
-        decoding_chunk_size = cfg.decoding_chunk_size
-        num_decoding_left_chunks = cfg.num_decoding_left_chunks
-
-        assert decoding_chunk_size > 0
-        subsampling = self.model.encoder.embed.subsampling_rate
-        context = self.model.encoder.embed.right_context + 1
-        stride = subsampling * decoding_chunk_size
-
-        # decoding window for model
-        decoding_window = (decoding_chunk_size - 1) * subsampling + context
-        num_frames = xs.shape[1]
-        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-
-        logger.info("start to do model forward")
-        outputs = []
-
-        # num_frames - context + 1 ensure that current frame can get context window
-        for cur in range(0, num_frames - context + 1, stride):
-            end = min(cur + decoding_window, num_frames)
-            chunk_xs = xs[:, cur:end, :]
-            (y, self.subsampling_cache, self.elayers_output_cache,
-             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
-                 chunk_xs, self.offset, required_cache_size,
-                 self.subsampling_cache, self.elayers_output_cache,
-                 self.conformer_cnn_cache)
-            outputs.append(y)
-            self.offset += y.shape[1]
-
-        ys = paddle.cat(outputs, 1)
-        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
-        masks = masks.unsqueeze(1)
-        return ys, masks
-
-    def rescoring(self, encoder_out, device):
-        logger.info("start to rescoring the hyps")
-        beam_size = self.config.decode.beam_size
-        hyps = self.searcher.get_hyps()
-        assert len(hyps) == beam_size
-
-        hyp_list = []
-        for hyp in hyps:
-            hyp_content = hyp[0]
-            # Prevent the hyp is empty
-            if len(hyp_content) == 0:
-                hyp_content = (self.model.ctc.blank_id, )
-            hyp_content = paddle.to_tensor(
-                hyp_content, place=device, dtype=paddle.long)
-            hyp_list.append(hyp_content)
-        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
-        hyps_lens = paddle.to_tensor(
-            [len(hyp[0]) for hyp in hyps], place=device,
-            dtype=paddle.long)  # (beam_size,)
-        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
-                                  self.model.ignore_id)
-        hyps_lens = hyps_lens + 1  # Add <sos> at begining
-
-        encoder_out = encoder_out.repeat(beam_size, 1, 1)
-        encoder_mask = paddle.ones(
-            (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
-        decoder_out, _ = self.model.decoder(
-            encoder_out, encoder_mask, hyps_pad,
-            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
-        # ctc score in ln domain
-        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
-        decoder_out = decoder_out.numpy()
-
-        # Only use decoder score for rescoring
-        best_score = -float('inf')
-        best_index = 0
-        # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
-        for i, hyp in enumerate(hyps):
-            score = 0.0
-            for j, w in enumerate(hyp[0]):
-                score += decoder_out[i][j][w]
-            # last decoder output token is `eos`, for laste decoder input token.
-            score += decoder_out[i][len(hyp[0])][self.model.eos]
-            # add ctc score (which in ln domain)
-            score += hyp[1] * self.config.decode.ctc_weight
-            if score > best_score:
-                best_score = score
-                best_index = i
-
-        # update the one best result
-        self.hyps = [hyps[best_index][0]]
-        return hyps[best_index][0]
-
-    def transformer_decode_reset(self):
-        self.subsampling_cache = None
-        self.elayers_output_cache = None
-        self.conformer_cnn_cache = None
-        self.offset = 0
-        # decoding reset
-        self.searcher.reset()
-
-    def update_result(self):
-        logger.info("update the final result")
-        hyps = self.hyps
-        self.result_transcripts = [
-            self.text_feature.defeaturize(hyp) for hyp in hyps
-        ]
-        self.result_tokenids = [hyp for hyp in hyps]
-
-    def extract_feat(self, samples, sample_rate):
-        """extract feat
-        Args:
-            samples (numpy.array): numpy.float32
-            sample_rate (int): sample rate
-        Returns:
-            x_chunk (numpy.array): shape[B, T, D]
-            x_chunk_lens (numpy.array): shape[B]
-        """
-        if "deepspeech2online" in self.model_type:
-            # pcm16 -> pcm 32
-            samples = pcm2float(samples)
-            # read audio
-            speech_segment = SpeechSegment.from_pcm(
-                samples, sample_rate, transcript=" ")
-            # audio augment
-            self.collate_fn_test.augmentation.transform_audio(speech_segment)
-
-            # extract speech feature
-            spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
-                speech_segment, self.collate_fn_test.keep_transcription_text)
-            # CMVN spectrum
-            if self.collate_fn_test._normalizer:
-                spectrum = self.collate_fn_test._normalizer.apply(spectrum)
-
-            # spectrum augment
-            audio = self.collate_fn_test.augmentation.transform_feature(
-                spectrum)
-
-            audio_len = audio.shape[0]
-            audio = paddle.to_tensor(audio, dtype='float32')
-            # audio_len = paddle.to_tensor(audio_len)
-            audio = paddle.unsqueeze(audio, axis=0)
-
-            x_chunk = audio.numpy()
-            x_chunk_lens = np.array([audio_len])
-
-            return x_chunk, x_chunk_lens
-        elif "conformer_online" in self.model_type:
-            if sample_rate != self.sample_rate:
-                logger.info(f"audio sample rate {sample_rate} is not match,"
-                            "the model sample_rate is {self.sample_rate}")
-            logger.info(f"ASR Engine use the {self.model_type} to process")
-            logger.info("Create the preprocess instance")
-            preprocess_conf = self.config.preprocess_config
-            preprocess_args = {"train": False}
-            preprocessing = Transformation(preprocess_conf)
-
-            logger.info("Read the audio file")
-            logger.info(f"audio shape: {samples.shape}")
-            # fbank
-            x_chunk = preprocessing(samples, **preprocess_args)
-            x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
-            x_chunk = paddle.to_tensor(
-                x_chunk, dtype="float32").unsqueeze(axis=0)
-            logger.info(
-                f"process the audio feature success, feat shape: {x_chunk.shape}"
-            )
-            return x_chunk, x_chunk_lens
-
-        raise ValueError(f"Not support: {model_type}")
-
-        return True


 class ASREngine(BaseEngine):
-    """ASR server engine
+    """ASR server resource

     Args:
         metaclass: Defaults to Singleton.
@@ -1015,7 +868,7 @@ class ASREngine(BaseEngine):
     def __init__(self):
         super(ASREngine, self).__init__()
-        logger.info("create the online asr engine instance")
+        logger.info("create the online asr engine resource instance")

     def init(self, config: dict) -> bool:
         """init engine resource
@@ -1026,17 +879,12 @@ class ASREngine(BaseEngine):
         Returns:
             bool: init failed or success
         """
-        self.input = None
-        self.output = ""
-        self.executor = ASRServerExecutor()
         self.config = config
+        self.executor = ASRServerExecutor()
+
         try:
-            if self.config.get("device", None):
-                self.device = self.config.device
-            else:
-                self.device = paddle.get_device()
-            logger.info(f"paddlespeech_server set the device: {self.device}")
-            paddle.set_device(self.device)
+            default_dev = paddle.get_device()
+            paddle.set_device(self.config.get("device", default_dev))
         except BaseException as e:
             logger.error(
                 f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
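The device-selection rewrite collapses the old if/else into one call by leaning on dict.get's default argument. A condensed sketch of the two shapes (toy config, not the server's actual yaml handling):

import paddle

config = {}  # imagine a server config with no "device" key

# old shape of the logic:
device = config.get("device", None) or paddle.get_device()
paddle.set_device(device)

# new shape of the logic, one call:
paddle.set_device(config.get("device", paddle.get_device()))

One loose end visible in the hunk: the except branch still interpolates self.device, which the rewritten happy path no longer assigns.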
@@ -1045,6 +893,8 @@ class ASREngine(BaseEngine):
                 "If all GPU or XPU is used, you can set the server to 'cpu'")
             sys.exit(-1)

+        logger.info(f"paddlespeech_server set the device: {self.device}")
+
         if not self.executor._init_from_path(
                 model_type=self.config.model_type,
                 am_model=self.config.am_model,
@@ -1062,42 +912,11 @@ class ASREngine(BaseEngine):
         logger.info("Initialize ASR server engine successfully.")
         return True

-    def preprocess(self,
-                   samples,
-                   sample_rate,
-                   model_type="deepspeech2online_aishell-zh-16k"):
-        """preprocess
-        Args:
-            samples (numpy.array): numpy.float32
-            sample_rate (int): sample rate
-        Returns:
-            x_chunk (numpy.array): shape[B, T, D]
-            x_chunk_lens (numpy.array): shape[B]
-        """
-        # if "deepspeech" in model_type:
-        x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
-        return x_chunk, x_chunk_lens
+    def preprocess(self, *args, **kwargs):
+        raise NotImplementedError("Online not using this.")

-    def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
-        """run online engine
-        Args:
-            x_chunk (numpy.array): shape[B, T, D]
-            x_chunk_lens (numpy.array): shape[B]
-            decoder_chunk_size(int)
-        """
-        self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens,
-                                                     self.config.model_type)
+    def run(self, *args, **kwargs):
+        raise NotImplementedError("Online not using this.")

     def postprocess(self):
-        """postprocess
-        """
-        return self.output
-
-    def reset(self):
-        """reset engine decoder and inference state
-        """
-        self.executor.reset_decoder_and_chunk()
-        self.output = ""
+        raise NotImplementedError("Online not using this.")