Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
8f9b7bba
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8f9b7bba
编写于
6月 07, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor asr online server
上级
f3132ce2
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
356 addition
and
232 deletion
+356
-232
.pre-commit-config.yaml
.pre-commit-config.yaml
+6
-6
demos/streaming_asr_server/server.sh
demos/streaming_asr_server/server.sh
+2
-1
demos/streaming_asr_server/test.sh
demos/streaming_asr_server/test.sh
+2
-1
paddlespeech/__init__.py
paddlespeech/__init__.py
+4
-0
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+199
-209
paddlespeech/server/engine/asr/online/ctc_endpoint.py
paddlespeech/server/engine/asr/online/ctc_endpoint.py
+108
-0
paddlespeech/server/engine/asr/online/ctc_search.py
paddlespeech/server/engine/asr/online/ctc_search.py
+35
-15
未找到文件。
.pre-commit-config.yaml
浏览文件 @
8f9b7bba
...
...
@@ -51,12 +51,12 @@ repos:
language
:
system
files
:
\.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude
:
(?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
-
id
:
copyright_checker
name
:
copyright_checker
entry
:
python .pre-commit-hooks/copyright-check.hook
language
:
system
files
:
\.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude
:
(?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
#
- id: copyright_checker
#
name: copyright_checker
#
entry: python .pre-commit-hooks/copyright-check.hook
#
language: system
#
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
#
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
-
repo
:
https://github.com/asottile/reorder_python_imports
rev
:
v2.4.0
hooks
:
...
...
demos/streaming_asr_server/server.sh
浏览文件 @
8f9b7bba
...
...
@@ -5,4 +5,5 @@ export CUDA_VISIBLE_DEVICE=0,1,2,3
paddlespeech_server start
--config_file
conf/punc_application.yaml &> punc.log &
# nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
paddlespeech_server start
--config_file
conf/ws_conformer_application.yaml &> streaming_asr.log &
\ No newline at end of file
paddlespeech_server start
--config_file
conf/ws_conformer_application.yaml &> streaming_asr.log &
demos/streaming_asr_server/test.sh
浏览文件 @
8f9b7bba
...
...
@@ -9,4 +9,5 @@ paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wa
# read the wav and call streaming and punc service
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
paddlespeech_client asr_online
--server_ip
127.0.0.1
--port
8290
--punc
.server_ip 127.0.0.1
--punc
.port 8190
--input
./zh.wav
\ No newline at end of file
paddlespeech_client asr_online
--server_ip
127.0.0.1
--port
8290
--punc
.server_ip 127.0.0.1
--punc
.port 8190
--input
./zh.wav
paddlespeech/__init__.py
浏览文件 @
8f9b7bba
...
...
@@ -14,3 +14,7 @@
import
_locale
_locale
.
_getdefaultlocale
=
(
lambda
*
args
:
[
'en_US'
,
'utf8'
])
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
8f9b7bba
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
os
import
sys
from
typing
import
ByteString
from
typing
import
Optional
import
numpy
as
np
...
...
@@ -30,9 +31,10 @@ from paddlespeech.s2t.transform.transformation import Transformation
from
paddlespeech.s2t.utils.tensor_utils
import
add_sos_eos
from
paddlespeech.s2t.utils.tensor_utils
import
pad_sequence
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.server.engine.asr.online.ctc_endpoint
import
OnlineCTCEndpoingOpt
from
paddlespeech.server.engine.asr.online.ctc_endpoint
import
OnlineCTCEndpoint
from
paddlespeech.server.engine.asr.online.ctc_search
import
CTCPrefixBeamSearch
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.utils.audio_process
import
pcm2float
from
paddlespeech.server.utils.paddle_predictor
import
init_predictor
__all__
=
[
'PaddleASRConnectionHanddler'
,
'ASRServerExecutor'
,
'ASREngine'
]
...
...
@@ -54,24 +56,33 @@ class PaddleASRConnectionHanddler:
self
.
model_config
=
asr_engine
.
executor
.
config
self
.
asr_engine
=
asr_engine
self
.
init
()
self
.
reset
()
def
init
(
self
):
# model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
self
.
model_type
=
self
.
asr_engine
.
executor
.
model_type
self
.
sample_rate
=
self
.
asr_engine
.
executor
.
sample_rate
# tokens to text
self
.
text_feature
=
self
.
asr_engine
.
executor
.
text_feature
# extract feat, new only fbank in conformer model
self
.
preprocess_conf
=
self
.
model_config
.
preprocess_config
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
# frame window and frame shift, in samples unit
self
.
win_length
=
self
.
preprocess_conf
.
process
[
0
][
'win_length'
]
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
assert
self
.
preprocess_conf
.
process
[
0
][
'fs'
]
==
self
.
sample_rate
,
(
self
.
sample_rate
,
self
.
preprocess_conf
.
process
[
0
][
'fs'
])
self
.
frame_shift_in_ms
=
int
(
self
.
n_shift
/
self
.
preprocess_conf
.
process
[
0
][
'fs'
]
*
1000
)
self
.
init_decoder
()
self
.
reset
()
def
init_decoder
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
self
.
am_predictor
=
self
.
asr_engine
.
executor
.
am_predictor
# extract feat, new only fbank in conformer model
self
.
preprocess_conf
=
self
.
model_config
.
preprocess_config
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
self
.
decoder
=
CTCDecoder
(
odim
=
self
.
model_config
.
output_dim
,
# <blank> is in vocab
enc_n_units
=
self
.
model_config
.
rnn_layer_size
*
2
,
...
...
@@ -90,10 +101,6 @@ class PaddleASRConnectionHanddler:
cfg
.
beam_size
,
cfg
.
cutoff_prob
,
cfg
.
cutoff_top_n
,
cfg
.
num_proc_bsearch
)
# frame window and frame shift, in samples unit
self
.
win_length
=
self
.
preprocess_conf
.
process
[
0
][
'win_length'
]
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
# acoustic model
self
.
model
=
self
.
asr_engine
.
executor
.
model
...
...
@@ -102,130 +109,40 @@ class PaddleASRConnectionHanddler:
self
.
ctc_decode_config
=
self
.
asr_engine
.
executor
.
config
.
decode
self
.
searcher
=
CTCPrefixBeamSearch
(
self
.
ctc_decode_config
)
# extract feat, new only fbank in conformer model
self
.
preprocess_conf
=
self
.
model_config
.
preprocess_config
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
# frame window and frame shift, in samples unit
self
.
win_length
=
self
.
preprocess_conf
.
process
[
0
][
'win_length'
]
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
# ctc endpoint
self
.
endpoint_opt
=
OnlineCTCEndpoingOpt
(
frame_shift_in_ms
=
self
.
frame_shift_in_ms
,
blank
=
0
)
self
.
endpointer
=
OnlineCTCEndpoint
(
self
.
endpoint_opt
)
else
:
raise
ValueError
(
f
"Not supported:
{
self
.
model_type
}
"
)
def
extract_feat
(
self
,
samples
):
# we compute the elapsed time of first char occuring
# and we record the start time at the first pcm sample arraving
if
"deepspeech2online"
in
self
.
model_type
:
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
if
self
.
remained_wav
is
None
:
self
.
remained_wav
=
samples
else
:
assert
self
.
remained_wav
.
ndim
==
1
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
logger
.
info
(
f
"The connection remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
# fbank
feat
=
self
.
preprocessing
(
self
.
remained_wav
,
**
self
.
preprocess_args
)
feat
=
paddle
.
to_tensor
(
feat
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
if
self
.
cached_feat
is
None
:
self
.
cached_feat
=
feat
else
:
assert
(
len
(
feat
.
shape
)
==
3
)
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
self
.
cached_feat
=
paddle
.
concat
(
[
self
.
cached_feat
,
feat
],
axis
=
1
)
# set the feat device
if
self
.
device
is
None
:
self
.
device
=
self
.
cached_feat
.
place
# cur frame step
num_frames
=
feat
.
shape
[
1
]
self
.
num_frames
+=
num_frames
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
logger
.
info
(
f
"process the audio feature success, the connection feat shape:
{
self
.
cached_feat
.
shape
}
"
)
logger
.
info
(
f
"After extract feat, the connection remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
elif
"conformer_online"
in
self
.
model_type
:
logger
.
info
(
"Online ASR extract the feat"
)
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
self
.
num_samples
+=
samples
.
shape
[
0
]
logger
.
info
(
f
"This package receive
{
samples
.
shape
[
0
]
}
pcm data. Global samples:
{
self
.
num_samples
}
"
)
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
if
self
.
remained_wav
is
None
:
self
.
remained_wav
=
samples
else
:
assert
self
.
remained_wav
.
ndim
==
1
# (T,)
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
logger
.
info
(
f
"The concatenation of remain and now audio samples length is:
{
self
.
remained_wav
.
shape
}
"
)
if
len
(
self
.
remained_wav
)
<
self
.
win_length
:
# samples not enough for feature window
return
0
# fbank
x_chunk
=
self
.
preprocessing
(
self
.
remained_wav
,
**
self
.
preprocess_args
)
x_chunk
=
paddle
.
to_tensor
(
x_chunk
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
# feature cache
if
self
.
cached_feat
is
None
:
self
.
cached_feat
=
x_chunk
else
:
assert
(
len
(
x_chunk
.
shape
)
==
3
)
# (B,T,D)
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
# (B,T,D)
self
.
cached_feat
=
paddle
.
concat
(
[
self
.
cached_feat
,
x_chunk
],
axis
=
1
)
# set the feat device
if
self
.
device
is
None
:
self
.
device
=
self
.
cached_feat
.
place
def
model_reset
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
return
# cur frame step
num_frames
=
x_chunk
.
shape
[
1
]
# feature cache
self
.
cached_feat
=
None
# global frame step
self
.
num_frames
+=
num_frames
## conformer
# cache for conformer online
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
encoder_out
=
None
# conformer decoding state
self
.
offset
=
0
# global offset in decoding frame unit
# update remained wav
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
## just for record info
self
.
chunk_num
=
0
# global decoding chunk num, not used
logger
.
info
(
f
"process the audio feature success, the cached feat shape:
{
self
.
cached_feat
.
shape
}
"
)
logger
.
info
(
f
"After extract feat, the cached remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
logger
.
info
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
info
(
f
"global frames:
{
self
.
num_frames
}
"
)
else
:
raise
ValueError
(
f
"not supported:
{
self
.
model_type
}
"
)
def
reset_continuous_decoding
(
self
):
"""
when in continous decoding, reset for next utterance.
"""
self
.
global_frame_offset
=
self
.
num_frames
self
.
model_reset
()
self
.
searcher
.
reset
()
self
.
endpointer
.
reset
()
def
reset
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
...
...
@@ -241,53 +158,110 @@ class PaddleASRConnectionHanddler:
dtype
=
float32
)
self
.
decoder
.
reset_decoder
(
batch_size
=
1
)
if
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
self
.
searcher
.
reset
()
self
.
endpointer
.
reset
()
self
.
device
=
None
## common
# global sample and frame step
self
.
num_samples
=
0
self
.
global_frame_offset
=
0
# frame step of cur utterance
self
.
num_frames
=
0
# cache for audio and feat
self
.
remained_wav
=
None
self
.
cached_feat
=
None
# partial/ending decoding results
self
.
result_transcripts
=
[
''
]
## conformer
self
.
model_reset
()
# cache for conformer online
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
encoder_out
=
None
# conformer decoding state
self
.
chunk_num
=
0
# globa decoding chunk num
self
.
offset
=
0
# global offset in decoding frame unit
self
.
hyps
=
[]
## outputs
# partial/ending decoding results
self
.
result_transcripts
=
[
''
]
# token timestamp result
self
.
word_time_stamp
=
[]
## just for record
self
.
hyps
=
[]
# one best timestamp viterbi prob is large.
self
.
time_stamp
=
[]
def
extract_feat
(
self
,
samples
:
ByteString
):
logger
.
info
(
"Online ASR extract the feat"
)
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
self
.
num_samples
+=
samples
.
shape
[
0
]
logger
.
info
(
f
"This package receive
{
samples
.
shape
[
0
]
}
pcm data. Global samples:
{
self
.
num_samples
}
"
)
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
if
self
.
remained_wav
is
None
:
self
.
remained_wav
=
samples
else
:
assert
self
.
remained_wav
.
ndim
==
1
# (T,)
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
logger
.
info
(
f
"The concatenation of remain and now audio samples length is:
{
self
.
remained_wav
.
shape
}
"
)
if
len
(
self
.
remained_wav
)
<
self
.
win_length
:
# samples not enough for feature window
return
0
# fbank
x_chunk
=
self
.
preprocessing
(
self
.
remained_wav
,
**
self
.
preprocess_args
)
x_chunk
=
paddle
.
to_tensor
(
x_chunk
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
# feature cache
if
self
.
cached_feat
is
None
:
self
.
cached_feat
=
x_chunk
else
:
assert
(
len
(
x_chunk
.
shape
)
==
3
)
# (B,T,D)
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
# (B,T,D)
self
.
cached_feat
=
paddle
.
concat
(
[
self
.
cached_feat
,
x_chunk
],
axis
=
1
)
# set the feat device
if
self
.
device
is
None
:
self
.
device
=
self
.
cached_feat
.
place
# cur frame step
num_frames
=
x_chunk
.
shape
[
1
]
# global frame step
self
.
num_frames
+=
num_frames
# update remained wav
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
logger
.
info
(
f
"process the audio feature success, the cached feat shape:
{
self
.
cached_feat
.
shape
}
"
)
logger
.
info
(
f
"After extract feat, the cached remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
logger
.
info
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
info
(
f
"global frames:
{
self
.
num_frames
}
"
)
def
decode
(
self
,
is_finished
=
False
):
"""advance decoding
Args:
is_finished (bool, optional): Is last frame or not. Defaults to False.
Raises:
Exception: when not support model.
Returns:
None:
nothing
None:
"""
if
"deepspeech2
online
"
in
self
.
model_type
:
if
"deepspeech2"
in
self
.
model_type
:
decoding_chunk_size
=
1
# decoding chunk size = 1. int decoding frame unit
context
=
7
# context=7, in audio frame unit
subsampling
=
4
# subsampling=4, in audio frame unit
...
...
@@ -332,9 +306,11 @@ class PaddleASRConnectionHanddler:
end
=
None
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
end
=
min
(
cur
+
decoding_window
,
num_frames
)
# extract the audio
x_chunk
=
self
.
cached_feat
[:,
cur
:
end
,
:].
numpy
()
x_chunk_lens
=
np
.
array
([
x_chunk
.
shape
[
1
]])
trans_best
=
self
.
decode_one_chunk
(
x_chunk
,
x_chunk_lens
)
self
.
result_transcripts
=
[
trans_best
]
...
...
@@ -409,31 +385,38 @@ class PaddleASRConnectionHanddler:
@
paddle
.
no_grad
()
def
advance_decoding
(
self
,
is_finished
=
False
):
if
"deepspeech"
in
self
.
model_type
:
return
logger
.
info
(
"Conformer/Transformer: start to decode with advanced_decoding method"
)
cfg
=
self
.
ctc_decode_config
# cur chunk size, in decoding frame unit
# cur chunk size, in decoding frame unit
, e.g. 16
decoding_chunk_size
=
cfg
.
decoding_chunk_size
# using num of history chunks
# using num of history chunks
, e.g -1
num_decoding_left_chunks
=
cfg
.
num_decoding_left_chunks
assert
decoding_chunk_size
>
0
# e.g. 4
subsampling
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
# e.g. 7
context
=
self
.
model
.
encoder
.
embed
.
right_context
+
1
# processed chunk feature cached for next chunk
# processed chunk feature cached for next chunk
, e.g. 3
cached_feature_num
=
context
-
subsampling
# decoding stride, in audio frame unit
stride
=
subsampling
*
decoding_chunk_size
# decoding window, in audio frame unit
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
# decoding stride, in audio frame unit
stride
=
subsampling
*
decoding_chunk_size
if
self
.
cached_feat
is
None
:
logger
.
info
(
"no audio feat, please input more pcm data"
)
return
# (B=1,T,D)
num_frames
=
self
.
cached_feat
.
shape
[
1
]
logger
.
info
(
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
...
...
@@ -454,9 +437,6 @@ class PaddleASRConnectionHanddler:
return
None
,
None
logger
.
info
(
"start to do model forward"
)
# hist of chunks, in deocding frame unit
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
outputs
=
[]
# num_frames - context + 1 ensure that current frame can get context window
if
is_finished
:
...
...
@@ -466,7 +446,11 @@ class PaddleASRConnectionHanddler:
# we only process decoding_window frames for one chunk
left_frames
=
decoding_window
# hist of chunks, in deocding frame unit
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
# record the end for removing the processed feat
outputs
=
[]
end
=
None
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
end
=
min
(
cur
+
decoding_window
,
num_frames
)
...
...
@@ -491,30 +475,28 @@ class PaddleASRConnectionHanddler:
self
.
encoder_out
=
ys
else
:
self
.
encoder_out
=
paddle
.
concat
([
self
.
encoder_out
,
ys
],
axis
=
1
)
logger
.
info
(
f
"This connection handler encoder out shape:
{
self
.
encoder_out
.
shape
}
"
)
# get the ctc probs
ctc_probs
=
self
.
model
.
ctc
.
log_softmax
(
ys
)
# (1, maxlen, vocab_size)
ctc_probs
=
ctc_probs
.
squeeze
(
0
)
## decoding
# advance decoding
self
.
searcher
.
search
(
ctc_probs
,
self
.
cached_feat
.
place
)
# get one best hyps
self
.
hyps
=
self
.
searcher
.
get_one_best_hyps
()
assert
self
.
cached_feat
.
shape
[
0
]
==
1
assert
end
>=
cached_feature_num
# advance cache of feat
self
.
cached_feat
=
self
.
cached_feat
[
0
,
end
-
cached_feature_num
:,
:].
unsqueeze
(
0
)
assert
self
.
cached_feat
.
shape
[
0
]
==
1
#(B=1,T,D)
assert
end
>=
cached_feature_num
self
.
cached_feat
=
self
.
cached_feat
[:,
end
-
cached_feature_num
:,
:]
assert
len
(
self
.
cached_feat
.
shape
)
==
3
,
f
"current cache feat shape is:
{
self
.
cached_feat
.
shape
}
"
logger
.
info
(
f
"This connection handler encoder out shape:
{
self
.
encoder_out
.
shape
}
"
)
def
update_result
(
self
):
"""Conformer/Transformer hyps to result.
"""
...
...
@@ -654,24 +636,28 @@ class PaddleASRConnectionHanddler:
# update each word start and end time stamp
# decoding frame to audio frame
frame_shift
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
frame_shift_in_sec
=
frame_shift
*
(
self
.
n_shift
/
self
.
sample_rate
)
logger
.
info
(
f
"frame shift sec:
{
frame_shift_in_sec
}
"
)
decode_frame_shift
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
decode_frame_shift_in_sec
=
decode_frame_shift
*
(
self
.
n_shift
/
self
.
sample_rate
)
logger
.
info
(
f
"decode frame shift in sec:
{
decode_frame_shift_in_sec
}
"
)
global_offset_in_sec
=
self
.
global_frame_offset
*
self
.
frame_shift_in_ms
/
1000.0
logger
.
info
(
f
"global offset:
{
global_offset_in_sec
}
sec."
)
word_time_stamp
=
[]
for
idx
,
_
in
enumerate
(
self
.
time_stamp
):
start
=
(
self
.
time_stamp
[
idx
-
1
]
+
self
.
time_stamp
[
idx
]
)
/
2.0
if
idx
>
0
else
0
start
=
start
*
frame_shift_in_sec
start
=
start
*
decode_
frame_shift_in_sec
end
=
(
self
.
time_stamp
[
idx
]
+
self
.
time_stamp
[
idx
+
1
]
)
/
2.0
if
idx
<
len
(
self
.
time_stamp
)
-
1
else
self
.
offset
end
=
end
*
frame_shift_in_sec
end
=
end
*
decode_
frame_shift_in_sec
word_time_stamp
.
append
({
"w"
:
self
.
result_transcripts
[
0
][
idx
],
"bg"
:
start
,
"ed"
:
end
"bg"
:
global_offset_in_sec
+
start
,
"ed"
:
global_offset_in_sec
+
end
})
# logger.info(f"{word_time_stamp[-1]}")
...
...
@@ -705,13 +691,14 @@ class ASRServerExecutor(ASRExecutor):
self
.
model_type
=
model_type
self
.
sample_rate
=
sample_rate
logger
.
info
(
f
"model_type:
{
self
.
model_type
}
"
)
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
self
.
task_resource
.
set_task_model
(
model_tag
=
tag
)
if
cfg_path
is
None
or
am_model
is
None
or
am_params
is
None
:
logger
.
info
(
f
"Load the pretrained model, tag =
{
tag
}
"
)
self
.
res_path
=
self
.
task_resource
.
res_dir
self
.
cfg_path
=
os
.
path
.
join
(
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'cfg_path'
])
...
...
@@ -719,7 +706,6 @@ class ASRServerExecutor(ASRExecutor):
self
.
task_resource
.
res_dict
[
'model'
])
self
.
am_params
=
os
.
path
.
join
(
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'params'
])
logger
.
info
(
self
.
res_path
)
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
am_model
=
os
.
path
.
abspath
(
am_model
)
...
...
@@ -727,9 +713,12 @@ class ASRServerExecutor(ASRExecutor):
self
.
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
logger
.
info
(
self
.
cfg_path
)
logger
.
info
(
self
.
am_model
)
logger
.
info
(
self
.
am_params
)
logger
.
info
(
"Load the pretrained model:"
)
logger
.
info
(
f
" tag =
{
tag
}
"
)
logger
.
info
(
f
" res_path:
{
self
.
res_path
}
"
)
logger
.
info
(
f
" cfg path:
{
self
.
cfg_path
}
"
)
logger
.
info
(
f
" am_model path:
{
self
.
am_model
}
"
)
logger
.
info
(
f
" am_params path:
{
self
.
am_params
}
"
)
#Init body.
self
.
config
=
CfgNode
(
new_allowed
=
True
)
...
...
@@ -738,25 +727,39 @@ class ASRServerExecutor(ASRExecutor):
if
self
.
config
.
spm_model_prefix
:
self
.
config
.
spm_model_prefix
=
os
.
path
.
join
(
self
.
res_path
,
self
.
config
.
spm_model_prefix
)
logger
.
info
(
f
"spm model path:
{
self
.
config
.
spm_model_prefix
}
"
)
self
.
vocab
=
self
.
config
.
vocab_filepath
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
unit_type
,
vocab
=
self
.
config
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
spm_model_prefix
)
self
.
vocab
=
self
.
config
.
vocab_filepath
with
UpdateConfig
(
self
.
config
):
if
"deepspeech2"
in
model_type
:
if
"deepspeech2"
in
model_type
:
with
UpdateConfig
(
self
.
config
):
# download lm
self
.
config
.
decode
.
lang_model_path
=
os
.
path
.
join
(
MODEL_HOME
,
'language_model'
,
self
.
config
.
decode
.
lang_model_path
)
lm_url
=
self
.
task_resource
.
res_dict
[
'lm_url'
]
lm_md5
=
self
.
task_resource
.
res_dict
[
'lm_md5'
]
logger
.
info
(
f
"Start to load language model
{
lm_url
}
"
)
self
.
download_lm
(
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
lm_url
=
self
.
task_resource
.
res_dict
[
'lm_url'
]
lm_md5
=
self
.
task_resource
.
res_dict
[
'lm_md5'
]
logger
.
info
(
f
"Start to load language model
{
lm_url
}
"
)
self
.
download_lm
(
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor_conf
=
am_predictor_conf
self
.
am_predictor
=
init_predictor
(
model_file
=
self
.
am_model
,
params_file
=
self
.
am_params
,
predictor_conf
=
self
.
am_predictor_conf
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
with
UpdateConfig
(
self
.
config
):
logger
.
info
(
"start to create the stream conformer asr engine"
)
# update the decoding method
if
decode_method
:
...
...
@@ -770,37 +773,24 @@ class ASRServerExecutor(ASRExecutor):
logger
.
info
(
"we set the decoding_method to attention_rescoring"
)
self
.
config
.
decode
.
decoding_method
=
"attention_rescoring"
assert
self
.
config
.
decode
.
decoding_method
in
[
"ctc_prefix_beam_search"
,
"attention_rescoring"
],
f
"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is
{
self
.
config
.
decode
.
decoding_method
}
"
else
:
raise
Exception
(
"wrong type"
)
if
"deepspeech2"
in
model_type
:
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor_conf
=
am_predictor_conf
self
.
am_predictor
=
init_predictor
(
model_file
=
self
.
am_model
,
params_file
=
self
.
am_params
,
predictor_conf
=
self
.
am_predictor_conf
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
# load model
model_name
=
model_type
[:
model_type
.
rindex
(
'_'
)]
# model_type: {model_name}_{dataset}
logger
.
info
(
f
"model name:
{
model_name
}
"
)
model_class
=
self
.
task_resource
.
get_model_class
(
model_name
)
model_conf
=
self
.
config
model
=
model_class
.
from_config
(
model_conf
)
model
=
model_class
.
from_config
(
self
.
config
)
self
.
model
=
model
self
.
model
.
set_state_dict
(
paddle
.
load
(
self
.
am_model
))
self
.
model
.
eval
()
# load model
model_dict
=
paddle
.
load
(
self
.
am_model
)
self
.
model
.
set_state_dict
(
model_dict
)
logger
.
info
(
"create the transformer like model success"
)
else
:
raise
ValueError
(
f
"N
ot support:
{
model_type
}
"
)
raise
Exception
(
f
"n
ot support:
{
model_type
}
"
)
logger
.
info
(
f
"create the
{
model_type
}
model success"
)
return
True
...
...
paddlespeech/server/engine/asr/online/ctc_endpoint.py
0 → 100644
浏览文件 @
8f9b7bba
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
from
typing
import
List
from
paddlespeech.cli.log
import
logger
@
dataclass
class
OnlineCTCEndpointRule
:
must_contain_nonsilence
:
bool
=
True
min_trailing_silence
:
int
=
1000
min_utterance_length
:
int
=
0
@
dataclass
class
OnlineCTCEndpoingOpt
:
frame_shift_in_ms
:
int
=
10
blank
:
int
=
0
# blank id, that we consider as silence for purposes of endpointing.
blank_threshold
:
float
=
0.8
# above blank threshold is silence
# We support three rules. We terminate decoding if ANY of these rules
# evaluates to "true". If you want to add more rules, do it by changing this
# code. If you want to disable a rule, you can set the silence-timeout for
# that rule to a very large number.
# rule1 times out after 5 seconds of silence, even if we decoded nothing.
rule1
:
OnlineCTCEndpointRule
=
OnlineCTCEndpointRule
(
False
,
5000
,
0
)
# rule4 times out after 1.0 seconds of silence after decoding something,
# even if we did not reach a final-state at all.
rule2
:
OnlineCTCEndpointRule
=
OnlineCTCEndpointRule
(
True
,
1000
,
0
)
# rule5 times out after the utterance is 20 seconds long, regardless of
# anything else.
rule3
:
OnlineCTCEndpointRule
=
OnlineCTCEndpointRule
(
False
,
0
,
20000
)
class
OnlineCTCEndpoint
:
"""
[END-TO-END AUTOMATIC SPEECH RECOGNITION INTEGRATED WITH CTC-BASED VOICE ACTIVITY DETECTION](https://arxiv.org/pdf/2002.00551.pdf)
"""
def
__init__
(
self
,
opts
:
OnlineCTCEndpoingOpt
):
self
.
opts
=
opts
logger
.
info
(
f
"Endpont Opts:
{
opts
}
"
)
self
.
frame_shift_in_ms
=
opts
.
frame_shift_in_ms
self
.
num_frames_decoded
=
0
self
.
trailing_silence_frames
=
0
self
.
reset
()
def
reset
(
self
):
self
.
num_frames_decoded
=
0
self
.
trailing_silence_frames
=
0
def
rule_activated
(
self
,
rule
:
OnlineCTCEndpointRule
,
rule_name
:
str
,
decoding_something
:
bool
,
trailine_silence
:
int
,
utterance_length
:
int
)
->
bool
:
ans
=
(
decoding_something
or
(
not
rule
.
must_contain_nonsilence
)
)
and
trailine_silence
>=
rule
.
min_trailing_silence
and
utterance_length
>=
rule
.
min_utterance_length
if
(
ans
):
logger
.
info
(
f
"Endpoint Rule:
{
rule_name
}
activated:
{
decoding_something
}
,
{
trailine_silence
}
,
{
utterance_length
}
"
)
return
ans
def
endpoint_detected
(
ctc_log_probs
:
List
[
List
[
float
]],
decoding_something
:
bool
)
->
bool
:
for
logprob
in
ctc_log_probs
:
blank_prob
=
exp
(
logprob
[
self
.
opts
.
blank_id
])
self
.
num_frames_decoded
+=
1
if
blank_prob
>
self
.
opts
.
blank_threshold
:
self
.
trailing_silence_frames
+=
1
else
:
self
.
trailing_silence_frames
=
0
assert
self
.
num_frames_decoded
>=
self
.
trailing_silence_frames
assert
self
.
frame_shift_in_ms
>
0
utterance_length
=
self
.
num_frames_decoded
*
self
.
frame_shift_in_ms
trailing_silence
=
self
.
trailing_silence_frames
*
self
.
frame_shift_in_ms
if
self
.
rule_activated
(
self
.
opts
.
rule1
,
'rule1'
,
decoding_something
,
trailing_silence
,
utterance_length
):
return
True
if
self
.
rule_activated
(
self
.
opts
.
rule2
,
'rule2'
,
decoding_something
,
trailing_silence
,
utterance_length
):
return
True
if
self
.
rule_activated
(
self
.
opts
.
rule3
,
'rule3'
,
decoding_something
,
trailing_silence
,
utterance_length
):
return
True
return
False
paddlespeech/server/engine/asr/online/ctc_search.py
浏览文件 @
8f9b7bba
...
...
@@ -30,8 +30,29 @@ class CTCPrefixBeamSearch:
config (yacs.config.CfgNode): the ctc prefix beam search configuration
"""
self
.
config
=
config
# beam size
self
.
first_beam_size
=
self
.
config
.
beam_size
# TODO(support second beam size)
self
.
second_beam_size
=
int
(
self
.
first_beam_size
*
1.0
)
logger
.
info
(
f
"first and second beam size:
{
self
.
first_beam_size
}
,
{
self
.
second_beam_size
}
"
)
# state
self
.
cur_hyps
=
None
self
.
hyps
=
None
self
.
abs_time_step
=
0
self
.
reset
()
def
reset
(
self
):
"""Rest the search cache value
"""
self
.
cur_hyps
=
None
self
.
hyps
=
None
self
.
abs_time_step
=
0
@
paddle
.
no_grad
()
def
search
(
self
,
ctc_probs
,
device
,
blank_id
=
0
):
"""ctc prefix beam search method decode a chunk feature
...
...
@@ -47,12 +68,17 @@ class CTCPrefixBeamSearch:
"""
# decode
logger
.
info
(
"start to ctc prefix search"
)
assert
len
(
ctc_probs
.
shape
)
==
2
batch_size
=
1
beam_size
=
self
.
config
.
beam_size
maxlen
=
ctc_probs
.
shape
[
0
]
assert
len
(
ctc_probs
.
shape
)
==
2
vocab_size
=
ctc_probs
.
shape
[
1
]
first_beam_size
=
min
(
self
.
first_beam_size
,
vocab_size
)
second_beam_size
=
min
(
self
.
second_beam_size
,
vocab_size
)
logger
.
info
(
f
"effect first and second beam size:
{
self
.
first_beam_size
}
,
{
self
.
second_beam_size
}
"
)
maxlen
=
ctc_probs
.
shape
[
0
]
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# 0. blank_ending_score,
...
...
@@ -75,7 +101,8 @@ class CTCPrefixBeamSearch:
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp
,
top_k_index
=
logp
.
topk
(
beam_size
)
# (beam_size,)
top_k_logp
,
top_k_index
=
logp
.
topk
(
first_beam_size
)
# (first_beam_size,)
for
s
in
top_k_index
:
s
=
s
.
item
()
ps
=
logp
[
s
].
item
()
...
...
@@ -148,7 +175,7 @@ class CTCPrefixBeamSearch:
next_hyps
.
items
(),
key
=
lambda
x
:
log_add
([
x
[
1
][
0
],
x
[
1
][
1
]]),
reverse
=
True
)
self
.
cur_hyps
=
next_hyps
[:
beam_size
]
self
.
cur_hyps
=
next_hyps
[:
second_
beam_size
]
# 2.3 update the absolute time step
self
.
abs_time_step
+=
1
...
...
@@ -163,7 +190,7 @@ class CTCPrefixBeamSearch:
"""Return the one best result
Returns:
list: the one best result
list: the one best result
, List[str]
"""
return
[
self
.
hyps
[
0
][
0
]]
...
...
@@ -171,17 +198,10 @@ class CTCPrefixBeamSearch:
"""Return the search hyps
Returns:
list: return the search hyps
list: return the search hyps
, List[Tuple[str, float, ...]]
"""
return
self
.
hyps
def
reset
(
self
):
"""Rest the search cache value
"""
self
.
cur_hyps
=
None
self
.
hyps
=
None
self
.
abs_time_step
=
0
def
finalize_search
(
self
):
"""do nothing in ctc_prefix_beam_search
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录