Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
f56dba0c
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
10 个月 前同步成功
通知
200
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
f56dba0c
编写于
4月 19, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the code format, test=doc
上级
380afbbc
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
64 addition
and
63 deletion
+64
-63
paddlespeech/cli/asr/infer.py
paddlespeech/cli/asr/infer.py
+1
-1
paddlespeech/server/conf/ws_conformer_application.yaml
paddlespeech/server/conf/ws_conformer_application.yaml
+1
-1
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+62
-61
未找到文件。
paddlespeech/cli/asr/infer.py
浏览文件 @
f56dba0c
...
...
@@ -129,7 +129,7 @@ model_alias = {
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline"
,
"conformer"
:
"paddlespeech.s2t.models.u2:U2Model"
,
"conformer
2
online"
:
"conformer
_
online"
:
"paddlespeech.s2t.models.u2:U2Model"
,
"transformer"
:
"paddlespeech.s2t.models.u2:U2Model"
,
...
...
paddlespeech/server/conf/ws_conformer_application.yaml
浏览文件 @
f56dba0c
...
...
@@ -21,7 +21,7 @@ engine_list: ['asr_online']
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online
:
model_type
:
'
conformer
2online_aishell
'
model_type
:
'
conformer
_online_multi-cn
'
am_model
:
# the pdmodel file of am static model [optional]
am_params
:
# the pdiparams file of am static model [optional]
lang
:
'
zh'
...
...
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
f56dba0c
...
...
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
import
os
from
typing
import
Optional
import
copy
import
numpy
as
np
import
paddle
from
numpy
import
float32
...
...
@@ -58,7 +59,7 @@ pretrained_models = {
'lm_md5'
:
'29e02312deb2e59b3c8686c7966d4fe3'
},
"conformer
2online_aishell
-zh-16k"
:
{
"conformer
_online_multi-cn
-zh-16k"
:
{
'url'
:
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz'
,
'md5'
:
...
...
@@ -93,19 +94,22 @@ class PaddleASRConnectionHanddler:
)
self
.
config
=
asr_engine
.
config
self
.
model_config
=
asr_engine
.
executor
.
config
# self.model = asr_engine.executor.model
self
.
asr_engine
=
asr_engine
self
.
init
()
self
.
reset
()
def
init
(
self
):
# model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
self
.
model_type
=
self
.
asr_engine
.
executor
.
model_type
self
.
sample_rate
=
self
.
asr_engine
.
executor
.
sample_rate
# tokens to text
self
.
text_feature
=
self
.
asr_engine
.
executor
.
text_feature
if
"deepspeech2online"
in
self
.
model_type
or
"deepspeech2offline"
in
self
.
model_type
:
from
paddlespeech.s2t.io.collator
import
SpeechCollator
self
.
sample_rate
=
self
.
asr_engine
.
executor
.
sample_rate
self
.
am_predictor
=
self
.
asr_engine
.
executor
.
am_predictor
self
.
text_feature
=
self
.
asr_engine
.
executor
.
text_feature
self
.
collate_fn_test
=
SpeechCollator
.
from_config
(
self
.
model_config
)
self
.
decoder
=
CTCDecoder
(
odim
=
self
.
model_config
.
output_dim
,
# <blank> is in vocab
...
...
@@ -114,7 +118,8 @@ class PaddleASRConnectionHanddler:
dropout_rate
=
0.0
,
reduction
=
True
,
# sum
batch_average
=
True
,
# sum / batch_size
grad_norm_type
=
self
.
model_config
.
get
(
'ctc_grad_norm_type'
,
None
))
grad_norm_type
=
self
.
model_config
.
get
(
'ctc_grad_norm_type'
,
None
))
cfg
=
self
.
model_config
.
decode
decode_batch_size
=
1
# for online
...
...
@@ -123,20 +128,16 @@ class PaddleASRConnectionHanddler:
cfg
.
decoding_method
,
cfg
.
lang_model_path
,
cfg
.
alpha
,
cfg
.
beta
,
cfg
.
beam_size
,
cfg
.
cutoff_prob
,
cfg
.
cutoff_top_n
,
cfg
.
num_proc_bsearch
)
# frame window samples length and frame shift samples length
self
.
win_length
=
int
(
self
.
model_config
.
window_ms
*
self
.
sample_rate
)
self
.
n_shift
=
int
(
self
.
model_config
.
stride_ms
*
self
.
sample_rate
)
# frame window samples length and frame shift samples length
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
or
"wenetspeech"
in
self
.
model_type
:
self
.
sample_rate
=
self
.
asr_engine
.
executor
.
sample_rate
self
.
win_length
=
int
(
self
.
model_config
.
window_ms
*
self
.
sample_rate
)
self
.
n_shift
=
int
(
self
.
model_config
.
stride_ms
*
self
.
sample_rate
)
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
# acoustic model
self
.
model
=
self
.
asr_engine
.
executor
.
model
# tokens to text
self
.
text_feature
=
self
.
asr_engine
.
executor
.
text_feature
# ctc decoding config
self
.
ctc_decode_config
=
self
.
asr_engine
.
executor
.
config
.
decode
self
.
searcher
=
CTCPrefixBeamSearch
(
self
.
ctc_decode_config
)
...
...
@@ -189,7 +190,7 @@ class PaddleASRConnectionHanddler:
audio
=
paddle
.
to_tensor
(
audio
,
dtype
=
'float32'
)
# audio_len = paddle.to_tensor(audio_len)
audio
=
paddle
.
unsqueeze
(
audio
,
axis
=
0
)
if
self
.
cached_feat
is
None
:
self
.
cached_feat
=
audio
else
:
...
...
@@ -211,7 +212,7 @@ class PaddleASRConnectionHanddler:
logger
.
info
(
f
"After extract feat, the connection remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
elif
"conformer
2
online"
in
self
.
model_type
:
elif
"conformer
_
online"
in
self
.
model_type
:
logger
.
info
(
"Online ASR extract the feat"
)
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
...
...
@@ -264,41 +265,43 @@ class PaddleASRConnectionHanddler:
def
reset
(
self
):
if
"deepspeech2online"
in
self
.
model_type
or
"deepspeech2offline"
in
self
.
model_type
:
# for deepspeech2
self
.
chunk_state_h_box
=
copy
.
deepcopy
(
self
.
asr_engine
.
executor
.
chunk_state_h_box
)
self
.
chunk_state_c_box
=
copy
.
deepcopy
(
self
.
asr_engine
.
executor
.
chunk_state_c_box
)
self
.
chunk_state_h_box
=
copy
.
deepcopy
(
self
.
asr_engine
.
executor
.
chunk_state_h_box
)
self
.
chunk_state_c_box
=
copy
.
deepcopy
(
self
.
asr_engine
.
executor
.
chunk_state_c_box
)
self
.
decoder
.
reset_decoder
(
batch_size
=
1
)
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
or
"wenetspeech"
in
self
.
model_type
:
# for conformer online
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
encoder_out
=
None
self
.
cached_feat
=
None
self
.
remained_wav
=
None
self
.
offset
=
0
self
.
num_samples
=
0
self
.
device
=
None
self
.
hyps
=
[]
self
.
num_frames
=
0
self
.
chunk_num
=
0
self
.
global_frame_offset
=
0
self
.
result_transcripts
=
[
''
]
# for conformer online
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
encoder_out
=
None
self
.
cached_feat
=
None
self
.
remained_wav
=
None
self
.
offset
=
0
self
.
num_samples
=
0
self
.
device
=
None
self
.
hyps
=
[]
self
.
num_frames
=
0
self
.
chunk_num
=
0
self
.
global_frame_offset
=
0
self
.
result_transcripts
=
[
''
]
def
decode
(
self
,
is_finished
=
False
):
if
"deepspeech2online"
in
self
.
model_type
:
# x_chunk 是特征数据
decoding_chunk_size
=
1
# decoding_chunk_size=1 in deepspeech2 model
context
=
7
# context=7 in deepspeech2 model
subsampling
=
4
# subsampling=4 in deepspeech2 model
decoding_chunk_size
=
1
# decoding_chunk_size=1 in deepspeech2 model
context
=
7
# context=7 in deepspeech2 model
subsampling
=
4
# subsampling=4 in deepspeech2 model
stride
=
subsampling
*
decoding_chunk_size
cached_feature_num
=
context
-
subsampling
# decoding window for model
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
if
self
.
cached_feat
is
None
:
logger
.
info
(
"no audio feat, please input more pcm data"
)
return
return
num_frames
=
self
.
cached_feat
.
shape
[
1
]
logger
.
info
(
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
...
...
@@ -306,14 +309,14 @@ class PaddleASRConnectionHanddler:
# the cached feat must be larger decoding_window
if
num_frames
<
decoding_window
and
not
is_finished
:
logger
.
info
(
f
"frame feat num is less than
{
decoding_window
}
, please input more pcm data"
f
"frame feat num is less than
{
decoding_window
}
, please input more pcm data"
)
return
None
,
None
# if is_finished=True, we need at least context frames
if
num_frames
<
context
:
logger
.
info
(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return
None
,
None
logger
.
info
(
"start to do model forward"
)
...
...
@@ -334,8 +337,7 @@ class PaddleASRConnectionHanddler:
self
.
result_transcripts
=
[
trans_best
]
self
.
cached_feat
=
self
.
cached_feat
[:,
end
-
cached_feature_num
:,
:]
self
.
cached_feat
=
self
.
cached_feat
[:,
end
-
cached_feature_num
:,
:]
# return trans_best[0]
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
try
:
...
...
@@ -354,8 +356,7 @@ class PaddleASRConnectionHanddler:
logger
.
info
(
"start to decoce one chunk with deepspeech2 model"
)
input_names
=
self
.
am_predictor
.
get_input_names
()
audio_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
0
])
audio_len_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
1
])
audio_len_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
1
])
h_box_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
2
])
c_box_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
3
])
...
...
@@ -374,11 +375,11 @@ class PaddleASRConnectionHanddler:
output_names
=
self
.
am_predictor
.
get_output_names
()
output_handle
=
self
.
am_predictor
.
get_output_handle
(
output_names
[
0
])
output_lens_handle
=
self
.
am_predictor
.
get_output_handle
(
output_names
[
1
])
output_names
[
1
])
output_state_h_handle
=
self
.
am_predictor
.
get_output_handle
(
output_names
[
2
])
output_names
[
2
])
output_state_c_handle
=
self
.
am_predictor
.
get_output_handle
(
output_names
[
3
])
output_names
[
3
])
self
.
am_predictor
.
run
()
...
...
@@ -389,7 +390,7 @@ class PaddleASRConnectionHanddler:
self
.
decoder
.
next
(
output_chunk_probs
,
output_chunk_lens
)
trans_best
,
trans_beam
=
self
.
decoder
.
decode
()
logger
.
info
(
f
"decode one
one
best result:
{
trans_best
[
0
]
}
"
)
logger
.
info
(
f
"decode one best result:
{
trans_best
[
0
]
}
"
)
return
trans_best
[
0
]
def
advance_decoding
(
self
,
is_finished
=
False
):
...
...
@@ -500,7 +501,7 @@ class PaddleASRConnectionHanddler:
def
rescoring
(
self
):
if
"deepspeech2online"
in
self
.
model_type
or
"deepspeech2offline"
in
self
.
model_type
:
return
logger
.
info
(
"rescoring the final result"
)
if
"attention_rescoring"
!=
self
.
ctc_decode_config
.
decoding_method
:
return
...
...
@@ -587,7 +588,7 @@ class ASRServerExecutor(ASRExecutor):
return
decompressed_path
def
_init_from_path
(
self
,
model_type
:
str
=
'
wenetspeech
'
,
model_type
:
str
=
'
deepspeech2online_aishell
'
,
am_model
:
Optional
[
os
.
PathLike
]
=
None
,
am_params
:
Optional
[
os
.
PathLike
]
=
None
,
lang
:
str
=
'zh'
,
...
...
@@ -647,7 +648,7 @@ class ASRServerExecutor(ASRExecutor):
self
.
download_lm
(
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
or
"wenetspeech"
in
model_type
:
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
logger
.
info
(
"start to create the stream conformer asr engine"
)
if
self
.
config
.
spm_model_prefix
:
self
.
config
.
spm_model_prefix
=
os
.
path
.
join
(
...
...
@@ -711,7 +712,7 @@ class ASRServerExecutor(ASRExecutor):
self
.
chunk_state_c_box
=
np
.
zeros
(
(
self
.
config
.
num_rnn_layers
,
1
,
self
.
config
.
rnn_layer_size
),
dtype
=
float32
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
or
"wenetspeech"
in
model_type
:
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
model_name
=
model_type
[:
model_type
.
rindex
(
'_'
)]
# model_type: {model_name}_{dataset}
logger
.
info
(
f
"model name:
{
model_name
}
"
)
...
...
@@ -742,7 +743,7 @@ class ASRServerExecutor(ASRExecutor):
self
.
chunk_state_c_box
=
np
.
zeros
(
(
self
.
config
.
num_rnn_layers
,
1
,
self
.
config
.
rnn_layer_size
),
dtype
=
float32
)
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
or
"wenetspeech"
in
self
.
model_type
:
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
self
.
transformer_decode_reset
()
def
decode_one_chunk
(
self
,
x_chunk
,
x_chunk_lens
,
model_type
:
str
):
...
...
@@ -754,7 +755,7 @@ class ASRServerExecutor(ASRExecutor):
model_type (str): online model type
Returns:
[type]: [description]
str: one best result
"""
logger
.
info
(
"start to decoce chunk by chunk"
)
if
"deepspeech2online"
in
model_type
:
...
...
@@ -795,7 +796,7 @@ class ASRServerExecutor(ASRExecutor):
self
.
decoder
.
next
(
output_chunk_probs
,
output_chunk_lens
)
trans_best
,
trans_beam
=
self
.
decoder
.
decode
()
logger
.
info
(
f
"decode one
one
best result:
{
trans_best
[
0
]
}
"
)
logger
.
info
(
f
"decode one best result:
{
trans_best
[
0
]
}
"
)
return
trans_best
[
0
]
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
...
...
@@ -972,7 +973,7 @@ class ASRServerExecutor(ASRExecutor):
x_chunk_lens
=
np
.
array
([
audio_len
])
return
x_chunk
,
x_chunk_lens
elif
"conformer
2
online"
in
self
.
model_type
:
elif
"conformer
_
online"
in
self
.
model_type
:
if
sample_rate
!=
self
.
sample_rate
:
logger
.
info
(
f
"audio sample rate
{
sample_rate
}
is not match,"
...
...
@@ -1005,7 +1006,7 @@ class ASREngine(BaseEngine):
def
__init__
(
self
):
super
(
ASREngine
,
self
).
__init__
()
logger
.
info
(
"create the online asr engine insta
ch
e"
)
logger
.
info
(
"create the online asr engine insta
nc
e"
)
def
init
(
self
,
config
:
dict
)
->
bool
:
"""init engine resource
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录