Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
8f9b7bba
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8f9b7bba
编写于
6月 07, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor asr online server
上级
f3132ce2
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
356 addition
and
232 deletion
+356
-232
.pre-commit-config.yaml
.pre-commit-config.yaml
+6
-6
demos/streaming_asr_server/server.sh
demos/streaming_asr_server/server.sh
+2
-1
demos/streaming_asr_server/test.sh
demos/streaming_asr_server/test.sh
+2
-1
paddlespeech/__init__.py
paddlespeech/__init__.py
+4
-0
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+199
-209
paddlespeech/server/engine/asr/online/ctc_endpoint.py
paddlespeech/server/engine/asr/online/ctc_endpoint.py
+108
-0
paddlespeech/server/engine/asr/online/ctc_search.py
paddlespeech/server/engine/asr/online/ctc_search.py
+35
-15
未找到文件。
.pre-commit-config.yaml
浏览文件 @
8f9b7bba
...
@@ -51,12 +51,12 @@ repos:
...
@@ -51,12 +51,12 @@ repos:
language
:
system
language
:
system
files
:
\.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
files
:
\.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude
:
(?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
exclude
:
(?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
-
id
:
copyright_checker
#
- id: copyright_checker
name
:
copyright_checker
#
name: copyright_checker
entry
:
python .pre-commit-hooks/copyright-check.hook
#
entry: python .pre-commit-hooks/copyright-check.hook
language
:
system
#
language: system
files
:
\.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
#
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude
:
(?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
#
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
-
repo
:
https://github.com/asottile/reorder_python_imports
-
repo
:
https://github.com/asottile/reorder_python_imports
rev
:
v2.4.0
rev
:
v2.4.0
hooks
:
hooks
:
...
...
demos/streaming_asr_server/server.sh
浏览文件 @
8f9b7bba
...
@@ -6,3 +6,4 @@ paddlespeech_server start --config_file conf/punc_application.yaml &> punc.log &
...
@@ -6,3 +6,4 @@ paddlespeech_server start --config_file conf/punc_application.yaml &> punc.log &
# nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
# nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
paddlespeech_server start
--config_file
conf/ws_conformer_application.yaml &> streaming_asr.log &
paddlespeech_server start
--config_file
conf/ws_conformer_application.yaml &> streaming_asr.log &
demos/streaming_asr_server/test.sh
浏览文件 @
8f9b7bba
...
@@ -10,3 +10,4 @@ paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wa
...
@@ -10,3 +10,4 @@ paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wa
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
paddlespeech_client asr_online
--server_ip
127.0.0.1
--port
8290
--punc
.server_ip 127.0.0.1
--punc
.port 8190
--input
./zh.wav
paddlespeech_client asr_online
--server_ip
127.0.0.1
--port
8290
--punc
.server_ip 127.0.0.1
--punc
.port 8190
--input
./zh.wav
paddlespeech/__init__.py
浏览文件 @
8f9b7bba
...
@@ -14,3 +14,7 @@
...
@@ -14,3 +14,7 @@
import
_locale
import
_locale
_locale
.
_getdefaultlocale
=
(
lambda
*
args
:
[
'en_US'
,
'utf8'
])
_locale
.
_getdefaultlocale
=
(
lambda
*
args
:
[
'en_US'
,
'utf8'
])
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
8f9b7bba
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
sys
import
sys
from
typing
import
ByteString
from
typing
import
Optional
from
typing
import
Optional
import
numpy
as
np
import
numpy
as
np
...
@@ -30,9 +31,10 @@ from paddlespeech.s2t.transform.transformation import Transformation
...
@@ -30,9 +31,10 @@ from paddlespeech.s2t.transform.transformation import Transformation
from
paddlespeech.s2t.utils.tensor_utils
import
add_sos_eos
from
paddlespeech.s2t.utils.tensor_utils
import
add_sos_eos
from
paddlespeech.s2t.utils.tensor_utils
import
pad_sequence
from
paddlespeech.s2t.utils.tensor_utils
import
pad_sequence
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.server.engine.asr.online.ctc_endpoint
import
OnlineCTCEndpoingOpt
from
paddlespeech.server.engine.asr.online.ctc_endpoint
import
OnlineCTCEndpoint
from
paddlespeech.server.engine.asr.online.ctc_search
import
CTCPrefixBeamSearch
from
paddlespeech.server.engine.asr.online.ctc_search
import
CTCPrefixBeamSearch
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.utils.audio_process
import
pcm2float
from
paddlespeech.server.utils.paddle_predictor
import
init_predictor
from
paddlespeech.server.utils.paddle_predictor
import
init_predictor
__all__
=
[
'PaddleASRConnectionHanddler'
,
'ASRServerExecutor'
,
'ASREngine'
]
__all__
=
[
'PaddleASRConnectionHanddler'
,
'ASRServerExecutor'
,
'ASREngine'
]
...
@@ -54,24 +56,33 @@ class PaddleASRConnectionHanddler:
...
@@ -54,24 +56,33 @@ class PaddleASRConnectionHanddler:
self
.
model_config
=
asr_engine
.
executor
.
config
self
.
model_config
=
asr_engine
.
executor
.
config
self
.
asr_engine
=
asr_engine
self
.
asr_engine
=
asr_engine
self
.
init
()
self
.
reset
()
def
init
(
self
):
# model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
# model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
self
.
model_type
=
self
.
asr_engine
.
executor
.
model_type
self
.
model_type
=
self
.
asr_engine
.
executor
.
model_type
self
.
sample_rate
=
self
.
asr_engine
.
executor
.
sample_rate
self
.
sample_rate
=
self
.
asr_engine
.
executor
.
sample_rate
# tokens to text
# tokens to text
self
.
text_feature
=
self
.
asr_engine
.
executor
.
text_feature
self
.
text_feature
=
self
.
asr_engine
.
executor
.
text_feature
if
"deepspeech2"
in
self
.
model_type
:
self
.
am_predictor
=
self
.
asr_engine
.
executor
.
am_predictor
# extract feat, new only fbank in conformer model
# extract feat, new only fbank in conformer model
self
.
preprocess_conf
=
self
.
model_config
.
preprocess_config
self
.
preprocess_conf
=
self
.
model_config
.
preprocess_config
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
# frame window and frame shift, in samples unit
self
.
win_length
=
self
.
preprocess_conf
.
process
[
0
][
'win_length'
]
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
assert
self
.
preprocess_conf
.
process
[
0
][
'fs'
]
==
self
.
sample_rate
,
(
self
.
sample_rate
,
self
.
preprocess_conf
.
process
[
0
][
'fs'
])
self
.
frame_shift_in_ms
=
int
(
self
.
n_shift
/
self
.
preprocess_conf
.
process
[
0
][
'fs'
]
*
1000
)
self
.
init_decoder
()
self
.
reset
()
def
init_decoder
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
self
.
am_predictor
=
self
.
asr_engine
.
executor
.
am_predictor
self
.
decoder
=
CTCDecoder
(
self
.
decoder
=
CTCDecoder
(
odim
=
self
.
model_config
.
output_dim
,
# <blank> is in vocab
odim
=
self
.
model_config
.
output_dim
,
# <blank> is in vocab
enc_n_units
=
self
.
model_config
.
rnn_layer_size
*
2
,
enc_n_units
=
self
.
model_config
.
rnn_layer_size
*
2
,
...
@@ -90,10 +101,6 @@ class PaddleASRConnectionHanddler:
...
@@ -90,10 +101,6 @@ class PaddleASRConnectionHanddler:
cfg
.
beam_size
,
cfg
.
cutoff_prob
,
cfg
.
cutoff_top_n
,
cfg
.
beam_size
,
cfg
.
cutoff_prob
,
cfg
.
cutoff_top_n
,
cfg
.
num_proc_bsearch
)
cfg
.
num_proc_bsearch
)
# frame window and frame shift, in samples unit
self
.
win_length
=
self
.
preprocess_conf
.
process
[
0
][
'win_length'
]
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
# acoustic model
# acoustic model
self
.
model
=
self
.
asr_engine
.
executor
.
model
self
.
model
=
self
.
asr_engine
.
executor
.
model
...
@@ -102,68 +109,88 @@ class PaddleASRConnectionHanddler:
...
@@ -102,68 +109,88 @@ class PaddleASRConnectionHanddler:
self
.
ctc_decode_config
=
self
.
asr_engine
.
executor
.
config
.
decode
self
.
ctc_decode_config
=
self
.
asr_engine
.
executor
.
config
.
decode
self
.
searcher
=
CTCPrefixBeamSearch
(
self
.
ctc_decode_config
)
self
.
searcher
=
CTCPrefixBeamSearch
(
self
.
ctc_decode_config
)
# extract feat, new only fbank in conformer model
# ctc endpoint
self
.
preprocess_conf
=
self
.
model_config
.
preprocess_config
self
.
endpoint_opt
=
OnlineCTCEndpoingOpt
(
self
.
preprocess_args
=
{
"train"
:
False
}
frame_shift_in_ms
=
self
.
frame_shift_in_ms
,
blank
=
0
)
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
self
.
endpointer
=
OnlineCTCEndpoint
(
self
.
endpoint_opt
)
# frame window and frame shift, in samples unit
self
.
win_length
=
self
.
preprocess_conf
.
process
[
0
][
'win_length'
]
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
else
:
else
:
raise
ValueError
(
f
"Not supported:
{
self
.
model_type
}
"
)
raise
ValueError
(
f
"Not supported:
{
self
.
model_type
}
"
)
def
extract_feat
(
self
,
samples
):
def
model_reset
(
self
):
# we compute the elapsed time of first char occuring
if
"deepspeech2"
in
self
.
model_type
:
# and we record the start time at the first pcm sample arraving
return
if
"deepspeech2online"
in
self
.
model_type
:
# feature cache
# self.reamined_wav stores all the samples,
self
.
cached_feat
=
None
# include the original remained_wav and this package samples
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
if
self
.
remained_wav
is
None
:
## conformer
self
.
remained_wav
=
samples
# cache for conformer online
else
:
self
.
subsampling_cache
=
None
assert
self
.
remained_wav
.
ndim
==
1
self
.
elayers_output_cache
=
None
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
self
.
conformer_cnn_cache
=
None
logger
.
info
(
self
.
encoder_out
=
None
f
"The connection remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
# conformer decoding state
)
self
.
offset
=
0
# global offset in decoding frame unit
# fbank
## just for record info
feat
=
self
.
preprocessing
(
self
.
remained_wav
,
self
.
chunk_num
=
0
# global decoding chunk num, not used
**
self
.
preprocess_args
)
feat
=
paddle
.
to_tensor
(
feat
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
if
self
.
cached_feat
is
None
:
def
reset_continuous_decoding
(
self
):
self
.
cached_feat
=
feat
"""
else
:
when in continous decoding, reset for next utterance.
assert
(
len
(
feat
.
shape
)
==
3
)
"""
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
self
.
global_frame_offset
=
self
.
num_frames
self
.
cached_feat
=
paddle
.
concat
(
self
.
model_reset
()
[
self
.
cached_feat
,
feat
],
axis
=
1
)
self
.
searcher
.
reset
()
self
.
endpointer
.
reset
()
# set the feat device
def
reset
(
self
):
if
self
.
device
is
None
:
if
"deepspeech2"
in
self
.
model_type
:
self
.
device
=
self
.
cached_feat
.
place
# for deepspeech2
# init state
self
.
chunk_state_h_box
=
np
.
zeros
(
(
self
.
model_config
.
num_rnn_layers
,
1
,
self
.
model_config
.
rnn_layer_size
),
dtype
=
float32
)
self
.
chunk_state_c_box
=
np
.
zeros
(
(
self
.
model_config
.
num_rnn_layers
,
1
,
self
.
model_config
.
rnn_layer_size
),
dtype
=
float32
)
self
.
decoder
.
reset_decoder
(
batch_size
=
1
)
# cur frame step
if
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
num_frames
=
feat
.
shape
[
1
]
self
.
searcher
.
reset
()
self
.
endpointer
.
reset
()
self
.
num_frames
+=
num_frames
self
.
device
=
None
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
logger
.
info
(
## common
f
"process the audio feature success, the connection feat shape:
{
self
.
cached_feat
.
shape
}
"
# global sample and frame step
)
self
.
num_samples
=
0
logger
.
info
(
self
.
global_frame_offset
=
0
f
"After extract feat, the connection remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
# frame step of cur utterance
)
self
.
num_frames
=
0
# cache for audio and feat
self
.
remained_wav
=
None
self
.
cached_feat
=
None
## conformer
self
.
model_reset
()
## outputs
# partial/ending decoding results
self
.
result_transcripts
=
[
''
]
# token timestamp result
self
.
word_time_stamp
=
[]
## just for record
self
.
hyps
=
[]
# one best timestamp viterbi prob is large.
self
.
time_stamp
=
[]
elif
"conformer_online"
in
self
.
model_type
:
def
extract_feat
(
self
,
samples
:
ByteString
)
:
logger
.
info
(
"Online ASR extract the feat"
)
logger
.
info
(
"Online ASR extract the feat"
)
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
assert
samples
.
ndim
==
1
...
@@ -189,10 +216,8 @@ class PaddleASRConnectionHanddler:
...
@@ -189,10 +216,8 @@ class PaddleASRConnectionHanddler:
return
0
return
0
# fbank
# fbank
x_chunk
=
self
.
preprocessing
(
self
.
remained_wav
,
x_chunk
=
self
.
preprocessing
(
self
.
remained_wav
,
**
self
.
preprocess_args
)
**
self
.
preprocess_args
)
x_chunk
=
paddle
.
to_tensor
(
x_chunk
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
x_chunk
=
paddle
.
to_tensor
(
x_chunk
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
# feature cache
# feature cache
if
self
.
cached_feat
is
None
:
if
self
.
cached_feat
is
None
:
...
@@ -224,55 +249,6 @@ class PaddleASRConnectionHanddler:
...
@@ -224,55 +249,6 @@ class PaddleASRConnectionHanddler:
)
)
logger
.
info
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
info
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
info
(
f
"global frames:
{
self
.
num_frames
}
"
)
logger
.
info
(
f
"global frames:
{
self
.
num_frames
}
"
)
else
:
raise
ValueError
(
f
"not supported:
{
self
.
model_type
}
"
)
def
reset
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
# for deepspeech2
# init state
self
.
chunk_state_h_box
=
np
.
zeros
(
(
self
.
model_config
.
num_rnn_layers
,
1
,
self
.
model_config
.
rnn_layer_size
),
dtype
=
float32
)
self
.
chunk_state_c_box
=
np
.
zeros
(
(
self
.
model_config
.
num_rnn_layers
,
1
,
self
.
model_config
.
rnn_layer_size
),
dtype
=
float32
)
self
.
decoder
.
reset_decoder
(
batch_size
=
1
)
self
.
device
=
None
## common
# global sample and frame step
self
.
num_samples
=
0
self
.
num_frames
=
0
# cache for audio and feat
self
.
remained_wav
=
None
self
.
cached_feat
=
None
# partial/ending decoding results
self
.
result_transcripts
=
[
''
]
## conformer
# cache for conformer online
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
encoder_out
=
None
# conformer decoding state
self
.
chunk_num
=
0
# globa decoding chunk num
self
.
offset
=
0
# global offset in decoding frame unit
self
.
hyps
=
[]
# token timestamp result
self
.
word_time_stamp
=
[]
# one best timestamp viterbi prob is large.
self
.
time_stamp
=
[]
def
decode
(
self
,
is_finished
=
False
):
def
decode
(
self
,
is_finished
=
False
):
"""advance decoding
"""advance decoding
...
@@ -280,14 +256,12 @@ class PaddleASRConnectionHanddler:
...
@@ -280,14 +256,12 @@ class PaddleASRConnectionHanddler:
Args:
Args:
is_finished (bool, optional): Is last frame or not. Defaults to False.
is_finished (bool, optional): Is last frame or not. Defaults to False.
Raises:
Exception: when not support model.
Returns:
Returns:
None:
nothing
None:
"""
"""
if
"deepspeech2
online
"
in
self
.
model_type
:
if
"deepspeech2"
in
self
.
model_type
:
decoding_chunk_size
=
1
# decoding chunk size = 1. int decoding frame unit
decoding_chunk_size
=
1
# decoding chunk size = 1. int decoding frame unit
context
=
7
# context=7, in audio frame unit
context
=
7
# context=7, in audio frame unit
subsampling
=
4
# subsampling=4, in audio frame unit
subsampling
=
4
# subsampling=4, in audio frame unit
...
@@ -332,9 +306,11 @@ class PaddleASRConnectionHanddler:
...
@@ -332,9 +306,11 @@ class PaddleASRConnectionHanddler:
end
=
None
end
=
None
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
end
=
min
(
cur
+
decoding_window
,
num_frames
)
end
=
min
(
cur
+
decoding_window
,
num_frames
)
# extract the audio
# extract the audio
x_chunk
=
self
.
cached_feat
[:,
cur
:
end
,
:].
numpy
()
x_chunk
=
self
.
cached_feat
[:,
cur
:
end
,
:].
numpy
()
x_chunk_lens
=
np
.
array
([
x_chunk
.
shape
[
1
]])
x_chunk_lens
=
np
.
array
([
x_chunk
.
shape
[
1
]])
trans_best
=
self
.
decode_one_chunk
(
x_chunk
,
x_chunk_lens
)
trans_best
=
self
.
decode_one_chunk
(
x_chunk
,
x_chunk_lens
)
self
.
result_transcripts
=
[
trans_best
]
self
.
result_transcripts
=
[
trans_best
]
...
@@ -409,31 +385,38 @@ class PaddleASRConnectionHanddler:
...
@@ -409,31 +385,38 @@ class PaddleASRConnectionHanddler:
@
paddle
.
no_grad
()
@
paddle
.
no_grad
()
def
advance_decoding
(
self
,
is_finished
=
False
):
def
advance_decoding
(
self
,
is_finished
=
False
):
if
"deepspeech"
in
self
.
model_type
:
return
logger
.
info
(
logger
.
info
(
"Conformer/Transformer: start to decode with advanced_decoding method"
"Conformer/Transformer: start to decode with advanced_decoding method"
)
)
cfg
=
self
.
ctc_decode_config
cfg
=
self
.
ctc_decode_config
# cur chunk size, in decoding frame unit
# cur chunk size, in decoding frame unit
, e.g. 16
decoding_chunk_size
=
cfg
.
decoding_chunk_size
decoding_chunk_size
=
cfg
.
decoding_chunk_size
# using num of history chunks
# using num of history chunks
, e.g -1
num_decoding_left_chunks
=
cfg
.
num_decoding_left_chunks
num_decoding_left_chunks
=
cfg
.
num_decoding_left_chunks
assert
decoding_chunk_size
>
0
assert
decoding_chunk_size
>
0
# e.g. 4
subsampling
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
subsampling
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
# e.g. 7
context
=
self
.
model
.
encoder
.
embed
.
right_context
+
1
context
=
self
.
model
.
encoder
.
embed
.
right_context
+
1
# processed chunk feature cached for next chunk
# processed chunk feature cached for next chunk
, e.g. 3
cached_feature_num
=
context
-
subsampling
cached_feature_num
=
context
-
subsampling
# decoding stride, in audio frame unit
stride
=
subsampling
*
decoding_chunk_size
# decoding window, in audio frame unit
# decoding window, in audio frame unit
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
# decoding stride, in audio frame unit
stride
=
subsampling
*
decoding_chunk_size
if
self
.
cached_feat
is
None
:
if
self
.
cached_feat
is
None
:
logger
.
info
(
"no audio feat, please input more pcm data"
)
logger
.
info
(
"no audio feat, please input more pcm data"
)
return
return
# (B=1,T,D)
num_frames
=
self
.
cached_feat
.
shape
[
1
]
num_frames
=
self
.
cached_feat
.
shape
[
1
]
logger
.
info
(
logger
.
info
(
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
...
@@ -454,9 +437,6 @@ class PaddleASRConnectionHanddler:
...
@@ -454,9 +437,6 @@ class PaddleASRConnectionHanddler:
return
None
,
None
return
None
,
None
logger
.
info
(
"start to do model forward"
)
logger
.
info
(
"start to do model forward"
)
# hist of chunks, in deocding frame unit
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
outputs
=
[]
# num_frames - context + 1 ensure that current frame can get context window
# num_frames - context + 1 ensure that current frame can get context window
if
is_finished
:
if
is_finished
:
...
@@ -466,7 +446,11 @@ class PaddleASRConnectionHanddler:
...
@@ -466,7 +446,11 @@ class PaddleASRConnectionHanddler:
# we only process decoding_window frames for one chunk
# we only process decoding_window frames for one chunk
left_frames
=
decoding_window
left_frames
=
decoding_window
# hist of chunks, in deocding frame unit
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
# record the end for removing the processed feat
# record the end for removing the processed feat
outputs
=
[]
end
=
None
end
=
None
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
end
=
min
(
cur
+
decoding_window
,
num_frames
)
end
=
min
(
cur
+
decoding_window
,
num_frames
)
...
@@ -491,30 +475,28 @@ class PaddleASRConnectionHanddler:
...
@@ -491,30 +475,28 @@ class PaddleASRConnectionHanddler:
self
.
encoder_out
=
ys
self
.
encoder_out
=
ys
else
:
else
:
self
.
encoder_out
=
paddle
.
concat
([
self
.
encoder_out
,
ys
],
axis
=
1
)
self
.
encoder_out
=
paddle
.
concat
([
self
.
encoder_out
,
ys
],
axis
=
1
)
logger
.
info
(
f
"This connection handler encoder out shape:
{
self
.
encoder_out
.
shape
}
"
)
# get the ctc probs
# get the ctc probs
ctc_probs
=
self
.
model
.
ctc
.
log_softmax
(
ys
)
# (1, maxlen, vocab_size)
ctc_probs
=
self
.
model
.
ctc
.
log_softmax
(
ys
)
# (1, maxlen, vocab_size)
ctc_probs
=
ctc_probs
.
squeeze
(
0
)
ctc_probs
=
ctc_probs
.
squeeze
(
0
)
## decoding
# advance decoding
# advance decoding
self
.
searcher
.
search
(
ctc_probs
,
self
.
cached_feat
.
place
)
self
.
searcher
.
search
(
ctc_probs
,
self
.
cached_feat
.
place
)
# get one best hyps
# get one best hyps
self
.
hyps
=
self
.
searcher
.
get_one_best_hyps
()
self
.
hyps
=
self
.
searcher
.
get_one_best_hyps
()
assert
self
.
cached_feat
.
shape
[
0
]
==
1
assert
end
>=
cached_feature_num
# advance cache of feat
# advance cache of feat
self
.
cached_feat
=
self
.
cached_feat
[
0
,
end
-
assert
self
.
cached_feat
.
shape
[
0
]
==
1
#(B=1,T,D)
cached_feature_num
:,
:].
unsqueeze
(
0
)
assert
end
>=
cached_feature_num
self
.
cached_feat
=
self
.
cached_feat
[:,
end
-
cached_feature_num
:,
:]
assert
len
(
assert
len
(
self
.
cached_feat
.
shape
self
.
cached_feat
.
shape
)
==
3
,
f
"current cache feat shape is:
{
self
.
cached_feat
.
shape
}
"
)
==
3
,
f
"current cache feat shape is:
{
self
.
cached_feat
.
shape
}
"
logger
.
info
(
f
"This connection handler encoder out shape:
{
self
.
encoder_out
.
shape
}
"
)
def
update_result
(
self
):
def
update_result
(
self
):
"""Conformer/Transformer hyps to result.
"""Conformer/Transformer hyps to result.
"""
"""
...
@@ -654,24 +636,28 @@ class PaddleASRConnectionHanddler:
...
@@ -654,24 +636,28 @@ class PaddleASRConnectionHanddler:
# update each word start and end time stamp
# update each word start and end time stamp
# decoding frame to audio frame
# decoding frame to audio frame
frame_shift
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
decode_frame_shift
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
frame_shift_in_sec
=
frame_shift
*
(
self
.
n_shift
/
self
.
sample_rate
)
decode_frame_shift_in_sec
=
decode_frame_shift
*
(
self
.
n_shift
/
logger
.
info
(
f
"frame shift sec:
{
frame_shift_in_sec
}
"
)
self
.
sample_rate
)
logger
.
info
(
f
"decode frame shift in sec:
{
decode_frame_shift_in_sec
}
"
)
global_offset_in_sec
=
self
.
global_frame_offset
*
self
.
frame_shift_in_ms
/
1000.0
logger
.
info
(
f
"global offset:
{
global_offset_in_sec
}
sec."
)
word_time_stamp
=
[]
word_time_stamp
=
[]
for
idx
,
_
in
enumerate
(
self
.
time_stamp
):
for
idx
,
_
in
enumerate
(
self
.
time_stamp
):
start
=
(
self
.
time_stamp
[
idx
-
1
]
+
self
.
time_stamp
[
idx
]
start
=
(
self
.
time_stamp
[
idx
-
1
]
+
self
.
time_stamp
[
idx
]
)
/
2.0
if
idx
>
0
else
0
)
/
2.0
if
idx
>
0
else
0
start
=
start
*
frame_shift_in_sec
start
=
start
*
decode_
frame_shift_in_sec
end
=
(
self
.
time_stamp
[
idx
]
+
self
.
time_stamp
[
idx
+
1
]
end
=
(
self
.
time_stamp
[
idx
]
+
self
.
time_stamp
[
idx
+
1
]
)
/
2.0
if
idx
<
len
(
self
.
time_stamp
)
-
1
else
self
.
offset
)
/
2.0
if
idx
<
len
(
self
.
time_stamp
)
-
1
else
self
.
offset
end
=
end
*
frame_shift_in_sec
end
=
end
*
decode_
frame_shift_in_sec
word_time_stamp
.
append
({
word_time_stamp
.
append
({
"w"
:
self
.
result_transcripts
[
0
][
idx
],
"w"
:
self
.
result_transcripts
[
0
][
idx
],
"bg"
:
start
,
"bg"
:
global_offset_in_sec
+
start
,
"ed"
:
end
"ed"
:
global_offset_in_sec
+
end
})
})
# logger.info(f"{word_time_stamp[-1]}")
# logger.info(f"{word_time_stamp[-1]}")
...
@@ -705,13 +691,14 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -705,13 +691,14 @@ class ASRServerExecutor(ASRExecutor):
self
.
model_type
=
model_type
self
.
model_type
=
model_type
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
logger
.
info
(
f
"model_type:
{
self
.
model_type
}
"
)
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
self
.
task_resource
.
set_task_model
(
model_tag
=
tag
)
self
.
task_resource
.
set_task_model
(
model_tag
=
tag
)
if
cfg_path
is
None
or
am_model
is
None
or
am_params
is
None
:
if
cfg_path
is
None
or
am_model
is
None
or
am_params
is
None
:
logger
.
info
(
f
"Load the pretrained model, tag =
{
tag
}
"
)
self
.
res_path
=
self
.
task_resource
.
res_dir
self
.
res_path
=
self
.
task_resource
.
res_dir
self
.
cfg_path
=
os
.
path
.
join
(
self
.
cfg_path
=
os
.
path
.
join
(
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'cfg_path'
])
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'cfg_path'
])
...
@@ -719,7 +706,6 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -719,7 +706,6 @@ class ASRServerExecutor(ASRExecutor):
self
.
task_resource
.
res_dict
[
'model'
])
self
.
task_resource
.
res_dict
[
'model'
])
self
.
am_params
=
os
.
path
.
join
(
self
.
res_path
,
self
.
am_params
=
os
.
path
.
join
(
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'params'
])
self
.
task_resource
.
res_dict
[
'params'
])
logger
.
info
(
self
.
res_path
)
else
:
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
am_model
=
os
.
path
.
abspath
(
am_model
)
self
.
am_model
=
os
.
path
.
abspath
(
am_model
)
...
@@ -727,9 +713,12 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -727,9 +713,12 @@ class ASRServerExecutor(ASRExecutor):
self
.
res_path
=
os
.
path
.
dirname
(
self
.
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
logger
.
info
(
self
.
cfg_path
)
logger
.
info
(
"Load the pretrained model:"
)
logger
.
info
(
self
.
am_model
)
logger
.
info
(
f
" tag =
{
tag
}
"
)
logger
.
info
(
self
.
am_params
)
logger
.
info
(
f
" res_path:
{
self
.
res_path
}
"
)
logger
.
info
(
f
" cfg path:
{
self
.
cfg_path
}
"
)
logger
.
info
(
f
" am_model path:
{
self
.
am_model
}
"
)
logger
.
info
(
f
" am_params path:
{
self
.
am_params
}
"
)
#Init body.
#Init body.
self
.
config
=
CfgNode
(
new_allowed
=
True
)
self
.
config
=
CfgNode
(
new_allowed
=
True
)
...
@@ -738,13 +727,18 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -738,13 +727,18 @@ class ASRServerExecutor(ASRExecutor):
if
self
.
config
.
spm_model_prefix
:
if
self
.
config
.
spm_model_prefix
:
self
.
config
.
spm_model_prefix
=
os
.
path
.
join
(
self
.
config
.
spm_model_prefix
=
os
.
path
.
join
(
self
.
res_path
,
self
.
config
.
spm_model_prefix
)
self
.
res_path
,
self
.
config
.
spm_model_prefix
)
logger
.
info
(
f
"spm model path:
{
self
.
config
.
spm_model_prefix
}
"
)
self
.
vocab
=
self
.
config
.
vocab_filepath
self
.
text_feature
=
TextFeaturizer
(
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
unit_type
,
unit_type
=
self
.
config
.
unit_type
,
vocab
=
self
.
config
.
vocab_filepath
,
vocab
=
self
.
config
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
spm_model_prefix
)
spm_model_prefix
=
self
.
config
.
spm_model_prefix
)
self
.
vocab
=
self
.
config
.
vocab_filepath
with
UpdateConfig
(
self
.
config
):
if
"deepspeech2"
in
model_type
:
if
"deepspeech2"
in
model_type
:
with
UpdateConfig
(
self
.
config
):
# download lm
self
.
config
.
decode
.
lang_model_path
=
os
.
path
.
join
(
self
.
config
.
decode
.
lang_model_path
=
os
.
path
.
join
(
MODEL_HOME
,
'language_model'
,
MODEL_HOME
,
'language_model'
,
self
.
config
.
decode
.
lang_model_path
)
self
.
config
.
decode
.
lang_model_path
)
...
@@ -756,7 +750,16 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -756,7 +750,16 @@ class ASRServerExecutor(ASRExecutor):
lm_url
,
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor_conf
=
am_predictor_conf
self
.
am_predictor
=
init_predictor
(
model_file
=
self
.
am_model
,
params_file
=
self
.
am_params
,
predictor_conf
=
self
.
am_predictor_conf
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
with
UpdateConfig
(
self
.
config
):
logger
.
info
(
"start to create the stream conformer asr engine"
)
logger
.
info
(
"start to create the stream conformer asr engine"
)
# update the decoding method
# update the decoding method
if
decode_method
:
if
decode_method
:
...
@@ -770,37 +773,24 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -770,37 +773,24 @@ class ASRServerExecutor(ASRExecutor):
logger
.
info
(
logger
.
info
(
"we set the decoding_method to attention_rescoring"
)
"we set the decoding_method to attention_rescoring"
)
self
.
config
.
decode
.
decoding_method
=
"attention_rescoring"
self
.
config
.
decode
.
decoding_method
=
"attention_rescoring"
assert
self
.
config
.
decode
.
decoding_method
in
[
assert
self
.
config
.
decode
.
decoding_method
in
[
"ctc_prefix_beam_search"
,
"attention_rescoring"
"ctc_prefix_beam_search"
,
"attention_rescoring"
],
f
"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is
{
self
.
config
.
decode
.
decoding_method
}
"
],
f
"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is
{
self
.
config
.
decode
.
decoding_method
}
"
else
:
raise
Exception
(
"wrong type"
)
if
"deepspeech2"
in
model_type
:
# load model
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor_conf
=
am_predictor_conf
self
.
am_predictor
=
init_predictor
(
model_file
=
self
.
am_model
,
params_file
=
self
.
am_params
,
predictor_conf
=
self
.
am_predictor_conf
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
model_name
=
model_type
[:
model_type
.
rindex
(
model_name
=
model_type
[:
model_type
.
rindex
(
'_'
)]
# model_type: {model_name}_{dataset}
'_'
)]
# model_type: {model_name}_{dataset}
logger
.
info
(
f
"model name:
{
model_name
}
"
)
logger
.
info
(
f
"model name:
{
model_name
}
"
)
model_class
=
self
.
task_resource
.
get_model_class
(
model_name
)
model_class
=
self
.
task_resource
.
get_model_class
(
model_name
)
model_conf
=
self
.
config
model
=
model_class
.
from_config
(
self
.
config
)
model
=
model_class
.
from_config
(
model_conf
)
self
.
model
=
model
self
.
model
=
model
self
.
model
.
set_state_dict
(
paddle
.
load
(
self
.
am_model
))
self
.
model
.
eval
()
self
.
model
.
eval
()
# load model
model_dict
=
paddle
.
load
(
self
.
am_model
)
self
.
model
.
set_state_dict
(
model_dict
)
logger
.
info
(
"create the transformer like model success"
)
else
:
else
:
raise
ValueError
(
f
"N
ot support:
{
model_type
}
"
)
raise
Exception
(
f
"n
ot support:
{
model_type
}
"
)
logger
.
info
(
f
"create the
{
model_type
}
model success"
)
return
True
return
True
...
...
paddlespeech/server/engine/asr/online/ctc_endpoint.py
0 → 100644
浏览文件 @
8f9b7bba
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from dataclasses import field
from typing import List

from paddlespeech.cli.log import logger
@dataclass
class OnlineCTCEndpointRule:
    """A single endpointing rule; it fires only when all conditions hold.

    The two numeric thresholds are compared against durations expressed in
    the same unit as ``frame_shift_in_ms`` (milliseconds).
    """
    # When True, the rule can only fire after something has been decoded.
    must_contain_nonsilence: bool = True
    # Minimum amount of trailing silence required for the rule to fire.
    min_trailing_silence: int = 1000
    # Minimum utterance length required for the rule to fire.
    min_utterance_length: int = 0


@dataclass
class OnlineCTCEndpoingOpt:
    """Options for CTC-blank based endpoint detection.

    NOTE: the class name keeps its original spelling ("Endpoing") because it
    is referenced by that name elsewhere.

    Each rule default is built through ``default_factory`` so that every
    options object owns its own rule instances. Using a rule instance
    directly as the default is rejected by ``dataclasses`` on Python >= 3.11
    (unhashable default) and, on older versions, silently shares one mutable
    object between all options instances.
    """
    frame_shift_in_ms: int = 10

    # blank id, that we consider as silence for purposes of endpointing.
    blank: int = 0
    # above blank threshold is silence
    blank_threshold: float = 0.8

    # We support three rules. We terminate decoding if ANY of these rules
    # evaluates to "true". If you want to add more rules, do it by changing this
    # code. If you want to disable a rule, you can set the silence-timeout for
    # that rule to a very large number.

    # rule1 times out after 5 seconds of silence, even if we decoded nothing.
    rule1: OnlineCTCEndpointRule = field(
        default_factory=lambda: OnlineCTCEndpointRule(False, 5000, 0))
    # rule2 times out after 1.0 seconds of silence after decoding something,
    # even if we did not reach a final-state at all.
    rule2: OnlineCTCEndpointRule = field(
        default_factory=lambda: OnlineCTCEndpointRule(True, 1000, 0))
    # rule3 times out after the utterance is 20 seconds long, regardless of
    # anything else.
    rule3: OnlineCTCEndpointRule = field(
        default_factory=lambda: OnlineCTCEndpointRule(False, 0, 20000))
class OnlineCTCEndpoint:
    """Streaming endpoint detector driven by CTC blank probabilities.

    [END-TO-END AUTOMATIC SPEECH RECOGNITION INTEGRATED WITH CTC-BASED VOICE ACTIVITY DETECTION](https://arxiv.org/pdf/2002.00551.pdf)

    Frames whose blank probability exceeds ``opts.blank_threshold`` are
    counted as silence. The running frame counters are converted to
    durations with ``opts.frame_shift_in_ms`` and checked against
    ``opts.rule1``, ``opts.rule2`` and ``opts.rule3``; any rule firing means
    an endpoint was detected.
    """

    def __init__(self, opts: "OnlineCTCEndpoingOpt"):
        """Store the options and start with zeroed frame counters.

        Args:
            opts: endpointing options (frame shift, blank id, blank
                threshold and the three rules).
        """
        self.opts = opts
        logger.info(f"Endpont Opts: {opts}")
        self.frame_shift_in_ms = opts.frame_shift_in_ms

        self.num_frames_decoded = 0
        self.trailing_silence_frames = 0

        self.reset()

    def reset(self):
        """Zero the decoded-frame and trailing-silence counters."""
        self.num_frames_decoded = 0
        self.trailing_silence_frames = 0

    def rule_activated(self,
                       rule: "OnlineCTCEndpointRule",
                       rule_name: str,
                       decoding_something: bool,
                       trailine_silence: int,
                       utterance_length: int) -> bool:
        """Check whether a single rule fires.

        Args:
            rule: the rule to evaluate.
            rule_name: label used only in the log message.
            decoding_something: whether anything has been decoded so far.
            trailine_silence: current trailing silence duration (the
                misspelled name is kept for keyword-argument compatibility).
            utterance_length: current utterance duration.

        Returns:
            True when the non-silence requirement is met and both duration
            thresholds of ``rule`` are reached.
        """
        nonsilence_ok = decoding_something or (
            not rule.must_contain_nonsilence)
        ans = bool(nonsilence_ok and
                   trailine_silence >= rule.min_trailing_silence and
                   utterance_length >= rule.min_utterance_length)
        if ans:
            logger.info(
                f"Endpoint Rule: {rule_name} activated: {decoding_something}, {trailine_silence}, {utterance_length}"
            )
        return ans

    def endpoint_detected(self,
                          ctc_log_probs: List[List[float]],
                          decoding_something: bool) -> bool:
        """Consume new CTC frames and report whether an endpoint is reached.

        Args:
            ctc_log_probs: one sequence of log-probabilities per frame; the
                entry at index ``opts.blank`` is the blank log-probability.
            decoding_something: whether anything has been decoded so far.

        Returns:
            True if rule1, rule2 or rule3 (checked in that order) fires after
            the counters have been updated with the given frames.
        """
        # The options dataclass names this field `blank` (not `blank_id`).
        blank = self.opts.blank
        threshold = self.opts.blank_threshold

        for logprob in ctc_log_probs:
            blank_prob = math.exp(logprob[blank])

            self.num_frames_decoded += 1
            if blank_prob > threshold:
                self.trailing_silence_frames += 1
            else:
                # Any non-silent frame restarts the trailing silence count.
                self.trailing_silence_frames = 0

        assert self.num_frames_decoded >= self.trailing_silence_frames
        assert self.frame_shift_in_ms > 0

        utterance_length = self.num_frames_decoded * self.frame_shift_in_ms
        trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms

        rules = (
            (self.opts.rule1, 'rule1'),
            (self.opts.rule2, 'rule2'),
            (self.opts.rule3, 'rule3'), )
        # any() short-circuits, so only the first firing rule is logged.
        return any(
            self.rule_activated(rule, name, decoding_something,
                                trailing_silence, utterance_length)
            for rule, name in rules)
paddlespeech/server/engine/asr/online/ctc_search.py
浏览文件 @
8f9b7bba
...
@@ -30,8 +30,29 @@ class CTCPrefixBeamSearch:
...
@@ -30,8 +30,29 @@ class CTCPrefixBeamSearch:
config (yacs.config.CfgNode): the ctc prefix beam search configuration
config (yacs.config.CfgNode): the ctc prefix beam search configuration
"""
"""
self
.
config
=
config
self
.
config
=
config
# beam size
self
.
first_beam_size
=
self
.
config
.
beam_size
# TODO(support second beam size)
self
.
second_beam_size
=
int
(
self
.
first_beam_size
*
1.0
)
logger
.
info
(
f
"first and second beam size:
{
self
.
first_beam_size
}
,
{
self
.
second_beam_size
}
"
)
# state
self
.
cur_hyps
=
None
self
.
hyps
=
None
self
.
abs_time_step
=
0
self
.
reset
()
self
.
reset
()
def reset(self):
    """Reset the cached search state.

    Sets ``cur_hyps`` and ``hyps`` to ``None`` and ``abs_time_step`` to 0.
    """
    self.abs_time_step = 0
    self.cur_hyps = self.hyps = None
@
paddle
.
no_grad
()
@
paddle
.
no_grad
()
def
search
(
self
,
ctc_probs
,
device
,
blank_id
=
0
):
def
search
(
self
,
ctc_probs
,
device
,
blank_id
=
0
):
"""ctc prefix beam search method decode a chunk feature
"""ctc prefix beam search method decode a chunk feature
...
@@ -47,12 +68,17 @@ class CTCPrefixBeamSearch:
...
@@ -47,12 +68,17 @@ class CTCPrefixBeamSearch:
"""
"""
# decode
# decode
logger
.
info
(
"start to ctc prefix search"
)
logger
.
info
(
"start to ctc prefix search"
)
assert
len
(
ctc_probs
.
shape
)
==
2
batch_size
=
1
batch_size
=
1
beam_size
=
self
.
config
.
beam_size
maxlen
=
ctc_probs
.
shape
[
0
]
assert
len
(
ctc_probs
.
shape
)
==
2
vocab_size
=
ctc_probs
.
shape
[
1
]
first_beam_size
=
min
(
self
.
first_beam_size
,
vocab_size
)
second_beam_size
=
min
(
self
.
second_beam_size
,
vocab_size
)
logger
.
info
(
f
"effect first and second beam size:
{
self
.
first_beam_size
}
,
{
self
.
second_beam_size
}
"
)
maxlen
=
ctc_probs
.
shape
[
0
]
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# 0. blank_ending_score,
# 0. blank_ending_score,
...
@@ -75,7 +101,8 @@ class CTCPrefixBeamSearch:
...
@@ -75,7 +101,8 @@ class CTCPrefixBeamSearch:
# 2.1 First beam prune: select topk best
# 2.1 First beam prune: select topk best
# do token passing process
# do token passing process
top_k_logp
,
top_k_index
=
logp
.
topk
(
beam_size
)
# (beam_size,)
top_k_logp
,
top_k_index
=
logp
.
topk
(
first_beam_size
)
# (first_beam_size,)
for
s
in
top_k_index
:
for
s
in
top_k_index
:
s
=
s
.
item
()
s
=
s
.
item
()
ps
=
logp
[
s
].
item
()
ps
=
logp
[
s
].
item
()
...
@@ -148,7 +175,7 @@ class CTCPrefixBeamSearch:
...
@@ -148,7 +175,7 @@ class CTCPrefixBeamSearch:
next_hyps
.
items
(),
next_hyps
.
items
(),
key
=
lambda
x
:
log_add
([
x
[
1
][
0
],
x
[
1
][
1
]]),
key
=
lambda
x
:
log_add
([
x
[
1
][
0
],
x
[
1
][
1
]]),
reverse
=
True
)
reverse
=
True
)
self
.
cur_hyps
=
next_hyps
[:
beam_size
]
self
.
cur_hyps
=
next_hyps
[:
second_
beam_size
]
# 2.3 update the absolute time step
# 2.3 update the absolute time step
self
.
abs_time_step
+=
1
self
.
abs_time_step
+=
1
...
@@ -163,7 +190,7 @@ class CTCPrefixBeamSearch:
...
@@ -163,7 +190,7 @@ class CTCPrefixBeamSearch:
"""Return the one best result
"""Return the one best result
Returns:
Returns:
list: the one best result
list: the one best result
, List[str]
"""
"""
return
[
self
.
hyps
[
0
][
0
]]
return
[
self
.
hyps
[
0
][
0
]]
...
@@ -171,17 +198,10 @@ class CTCPrefixBeamSearch:
...
@@ -171,17 +198,10 @@ class CTCPrefixBeamSearch:
"""Return the search hyps
"""Return the search hyps
Returns:
Returns:
list: return the search hyps
list: return the search hyps
, List[Tuple[str, float, ...]]
"""
"""
return
self
.
hyps
return
self
.
hyps
def reset(self):
    """Reset the cached search state.

    Sets ``cur_hyps`` and ``hyps`` to ``None`` and ``abs_time_step`` to 0.
    """
    self.abs_time_step = 0
    self.cur_hyps = self.hyps = None
def
finalize_search
(
self
):
def
finalize_search
(
self
):
"""do nothing in ctc_prefix_beam_search
"""do nothing in ctc_prefix_beam_search
"""
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录