Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
3cee7db0
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
206
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3cee7db0
编写于
6月 15, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
onnx ds2 straming asr
上级
c8574c7e
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
638 addition
and
21 deletion
+638
-21
demos/streaming_asr_server/conf/ws_ds2_application.yaml
demos/streaming_asr_server/conf/ws_ds2_application.yaml
+39
-4
paddlespeech/resource/pretrained_models.py
paddlespeech/resource/pretrained_models.py
+16
-0
paddlespeech/server/conf/ws_ds2_application.yaml
paddlespeech/server/conf/ws_ds2_application.yaml
+42
-7
paddlespeech/server/engine/asr/online/onnx/asr_engine.py
paddlespeech/server/engine/asr/online/onnx/asr_engine.py
+520
-0
paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py
...ch/server/engine/asr/online/paddleinference/asr_engine.py
+0
-1
paddlespeech/server/engine/asr/online/python/asr_engine.py
paddlespeech/server/engine/asr/online/python/asr_engine.py
+0
-1
paddlespeech/server/engine/engine_factory.py
paddlespeech/server/engine/engine_factory.py
+2
-1
paddlespeech/server/utils/onnx_infer.py
paddlespeech/server/utils/onnx_infer.py
+19
-7
未找到文件。
demos/streaming_asr_server/conf/ws_ds2_application.yaml
浏览文件 @
3cee7db0
...
...
@@ -7,11 +7,11 @@ host: 0.0.0.0
port
:
8090
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online']
# task choices = ['asr_online
-inference', 'asr_online-onnx
']
# protocol = ['websocket'] (only one can be selected).
# websocket only support online engine type.
protocol
:
'
websocket'
engine_list
:
[
'
asr_online-
inference
'
]
engine_list
:
[
'
asr_online-
onnx
'
]
#################################################################################
...
...
@@ -19,7 +19,7 @@ engine_list: ['asr_online-inference']
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
################### speech task: asr; engine_type: online
-inference
#######################
asr_online-inference
:
model_type
:
'
deepspeech2online_aishell'
am_model
:
# the pdmodel file of am static model [optional]
...
...
@@ -47,3 +47,38 @@ asr_online-inference:
shift_n
:
4
# frame
window_ms
:
20
# ms
shift_ms
:
10
# ms
################################### ASR #########################################
################### speech task: asr; engine_type: online-onnx #######################
asr_online-onnx
:
model_type
:
'
deepspeech2online_aishell'
am_model
:
# the pdmodel file of am static model [optional]
am_params
:
# the pdiparams file of am static model [optional]
lang
:
'
zh'
sample_rate
:
16000
cfg_path
:
decode_method
:
num_decoding_left_chunks
:
force_yes
:
True
device
:
'
cpu'
# cpu or gpu:id
# https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
am_predictor_conf
:
device
:
'
cpu'
# set 'gpu:id' or 'cpu'
graph_optimization_level
:
0
intra_op_num_threads
:
0
# Sets the number of threads used to parallelize the execution within nodes.
inter_op_num_threads
:
0
# Sets the number of threads used to parallelize the execution of the graph (across nodes).
log_severity_level
:
2
# Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
log_verbosity_level
:
0
# VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
chunk_buffer_conf
:
frame_duration_ms
:
80
shift_ms
:
40
sample_rate
:
16000
sample_width
:
2
window_n
:
7
# frame
shift_n
:
4
# frame
window_ms
:
20
# ms
shift_ms
:
10
# ms
paddlespeech/resource/pretrained_models.py
浏览文件 @
3cee7db0
...
...
@@ -15,6 +15,7 @@
__all__
=
[
'asr_dynamic_pretrained_models'
,
'asr_static_pretrained_models'
,
'asr_onnx_pretrained_models'
,
'cls_dynamic_pretrained_models'
,
'cls_static_pretrained_models'
,
'st_dynamic_pretrained_models'
,
...
...
@@ -246,6 +247,21 @@ asr_static_pretrained_models = {
},
}
asr_onnx_pretrained_models
=
{
"deepspeech2online_wenetspeech-zh-16k"
:
{
'1.0'
:
{
'url'
:
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz'
,
'md5'
:
'b0c77e7f8881e0a27b82127d1abb8d5f'
,
'cfg_path'
:
'model.yaml'
,
'ckpt_path'
:
'exp/deepspeech2_online/checkpoints/avg_10'
,
'lm_url'
:
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm'
,
'lm_md5'
:
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
}
# ---------------------------------
# -------------- CLS --------------
# ---------------------------------
...
...
paddlespeech/server/conf/ws_ds2_application.yaml
浏览文件 @
3cee7db0
...
...
@@ -7,11 +7,11 @@ host: 0.0.0.0
port
:
8090
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online
', 'tts_online
']
# protocol = ['websocket'
, 'http'
] (only one can be selected).
# task choices = ['asr_online
-inference', 'asr_online-onnx
']
# protocol = ['websocket'] (only one can be selected).
# websocket only support online engine type.
protocol
:
'
websocket'
engine_list
:
[
'
asr_online-
inference
'
]
engine_list
:
[
'
asr_online-
onnx
'
]
#################################################################################
...
...
@@ -19,7 +19,7 @@ engine_list: ['asr_online-inference']
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
################### speech task: asr; engine_type: online
-inference
#######################
asr_online-inference
:
model_type
:
'
deepspeech2online_aishell'
am_model
:
# the pdmodel file of am static model [optional]
...
...
@@ -30,7 +30,7 @@ asr_online-inference:
decode_method
:
num_decoding_left_chunks
:
force_yes
:
True
device
:
# cpu or gpu:id
device
:
'
cpu'
# cpu or gpu:id
am_predictor_conf
:
device
:
# set 'gpu:id' or 'cpu'
...
...
@@ -47,3 +47,38 @@ asr_online-inference:
shift_n
:
4
# frame
window_ms
:
20
# ms
shift_ms
:
10
# ms
################################### ASR #########################################
################### speech task: asr; engine_type: online-onnx #######################
asr_online-onnx
:
model_type
:
'
deepspeech2online_aishell'
am_model
:
# the pdmodel file of am static model [optional]
am_params
:
# the pdiparams file of am static model [optional]
lang
:
'
zh'
sample_rate
:
16000
cfg_path
:
decode_method
:
num_decoding_left_chunks
:
force_yes
:
True
device
:
'
cpu'
# cpu or gpu:id
# https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
am_predictor_conf
:
device
:
'
cpu'
# set 'gpu:id' or 'cpu'
graph_optimization_level
:
0
intra_op_num_threads
:
0
# Sets the number of threads used to parallelize the execution within nodes.
inter_op_num_threads
:
0
# Sets the number of threads used to parallelize the execution of the graph (across nodes).
log_severity_level
:
2
# Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
log_verbosity_level
:
0
# VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
chunk_buffer_conf
:
frame_duration_ms
:
80
shift_ms
:
40
sample_rate
:
16000
sample_width
:
2
window_n
:
7
# frame
shift_n
:
4
# frame
window_ms
:
20
# ms
shift_ms
:
10
# ms
paddlespeech/server/engine/asr/online/onnx/asr_engine.py
0 → 100644
浏览文件 @
3cee7db0
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
from
typing
import
ByteString
from
typing
import
Optional
import
numpy
as
np
import
paddle
from
numpy
import
float32
from
yacs.config
import
CfgNode
from
paddlespeech.cli.asr.infer
import
ASRExecutor
from
paddlespeech.cli.log
import
logger
from
paddlespeech.cli.utils
import
MODEL_HOME
from
paddlespeech.resource
import
CommonTaskResource
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.modules.ctc
import
CTCDecoder
from
paddlespeech.s2t.transform.transformation
import
Transformation
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.utils
import
onnx_infer
__all__
=
[
'PaddleASRConnectionHanddler'
,
'ASRServerExecutor'
,
'ASREngine'
]
# ASR server connection process class
class
PaddleASRConnectionHanddler
:
def
__init__
(
self
,
asr_engine
):
"""Init a Paddle ASR Connection Handler instance
Args:
asr_engine (ASREngine): the global asr engine
"""
super
().
__init__
()
logger
.
info
(
"create an paddle asr connection handler to process the websocket connection"
)
self
.
config
=
asr_engine
.
config
# server config
self
.
model_config
=
asr_engine
.
executor
.
config
self
.
asr_engine
=
asr_engine
# model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
self
.
model_type
=
self
.
asr_engine
.
executor
.
model_type
self
.
sample_rate
=
self
.
asr_engine
.
executor
.
sample_rate
# tokens to text
self
.
text_feature
=
self
.
asr_engine
.
executor
.
text_feature
# extract feat, new only fbank in conformer model
self
.
preprocess_conf
=
self
.
model_config
.
preprocess_config
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
# frame window and frame shift, in samples unit
self
.
win_length
=
self
.
preprocess_conf
.
process
[
0
][
'win_length'
]
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
assert
self
.
preprocess_conf
.
process
[
0
][
'fs'
]
==
self
.
sample_rate
,
(
self
.
sample_rate
,
self
.
preprocess_conf
.
process
[
0
][
'fs'
])
self
.
frame_shift_in_ms
=
int
(
self
.
n_shift
/
self
.
preprocess_conf
.
process
[
0
][
'fs'
]
*
1000
)
self
.
continuous_decoding
=
self
.
config
.
get
(
"continuous_decoding"
,
False
)
self
.
init_decoder
()
self
.
reset
()
def
init_decoder
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
assert
self
.
continuous_decoding
is
False
,
"ds2 model not support endpoint"
self
.
am_predictor
=
self
.
asr_engine
.
executor
.
am_predictor
self
.
decoder
=
CTCDecoder
(
odim
=
self
.
model_config
.
output_dim
,
# <blank> is in vocab
enc_n_units
=
self
.
model_config
.
rnn_layer_size
*
2
,
blank_id
=
self
.
model_config
.
blank_id
,
dropout_rate
=
0.0
,
reduction
=
True
,
# sum
batch_average
=
True
,
# sum / batch_size
grad_norm_type
=
self
.
model_config
.
get
(
'ctc_grad_norm_type'
,
None
))
cfg
=
self
.
model_config
.
decode
decode_batch_size
=
1
# for online
self
.
decoder
.
init_decoder
(
decode_batch_size
,
self
.
text_feature
.
vocab_list
,
cfg
.
decoding_method
,
cfg
.
lang_model_path
,
cfg
.
alpha
,
cfg
.
beta
,
cfg
.
beam_size
,
cfg
.
cutoff_prob
,
cfg
.
cutoff_top_n
,
cfg
.
num_proc_bsearch
)
else
:
raise
ValueError
(
f
"Not supported:
{
self
.
model_type
}
"
)
def
model_reset
(
self
):
# cache for audio and feat
self
.
remained_wav
=
None
self
.
cached_feat
=
None
def
output_reset
(
self
):
## outputs
# partial/ending decoding results
self
.
result_transcripts
=
[
''
]
def
reset_continuous_decoding
(
self
):
"""
when in continous decoding, reset for next utterance.
"""
self
.
global_frame_offset
=
self
.
num_frames
self
.
model_reset
()
def
reset
(
self
):
if
"deepspeech2"
in
self
.
model_type
:
# for deepspeech2
# init state
self
.
chunk_state_h_box
=
np
.
zeros
(
(
self
.
model_config
.
num_rnn_layers
,
1
,
self
.
model_config
.
rnn_layer_size
),
dtype
=
float32
)
self
.
chunk_state_c_box
=
np
.
zeros
(
(
self
.
model_config
.
num_rnn_layers
,
1
,
self
.
model_config
.
rnn_layer_size
),
dtype
=
float32
)
self
.
decoder
.
reset_decoder
(
batch_size
=
1
)
else
:
raise
NotImplementedError
(
f
"
{
self
.
model_type
}
not support."
)
self
.
device
=
None
## common
# global sample and frame step
self
.
num_samples
=
0
self
.
global_frame_offset
=
0
# frame step of cur utterance
self
.
num_frames
=
0
## endpoint
self
.
endpoint_state
=
False
# True for detect endpoint
## conformer
self
.
model_reset
()
## outputs
self
.
output_reset
()
def
extract_feat
(
self
,
samples
:
ByteString
):
logger
.
info
(
"Online ASR extract the feat"
)
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
self
.
num_samples
+=
samples
.
shape
[
0
]
logger
.
info
(
f
"This package receive
{
samples
.
shape
[
0
]
}
pcm data. Global samples:
{
self
.
num_samples
}
"
)
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
if
self
.
remained_wav
is
None
:
self
.
remained_wav
=
samples
else
:
assert
self
.
remained_wav
.
ndim
==
1
# (T,)
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
logger
.
info
(
f
"The concatenation of remain and now audio samples length is:
{
self
.
remained_wav
.
shape
}
"
)
if
len
(
self
.
remained_wav
)
<
self
.
win_length
:
# samples not enough for feature window
return
0
# fbank
x_chunk
=
self
.
preprocessing
(
self
.
remained_wav
,
**
self
.
preprocess_args
)
x_chunk
=
paddle
.
to_tensor
(
x_chunk
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
# feature cache
if
self
.
cached_feat
is
None
:
self
.
cached_feat
=
x_chunk
else
:
assert
(
len
(
x_chunk
.
shape
)
==
3
)
# (B,T,D)
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
# (B,T,D)
self
.
cached_feat
=
paddle
.
concat
(
[
self
.
cached_feat
,
x_chunk
],
axis
=
1
)
# set the feat device
if
self
.
device
is
None
:
self
.
device
=
self
.
cached_feat
.
place
# cur frame step
num_frames
=
x_chunk
.
shape
[
1
]
# global frame step
self
.
num_frames
+=
num_frames
# update remained wav
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
logger
.
info
(
f
"process the audio feature success, the cached feat shape:
{
self
.
cached_feat
.
shape
}
"
)
logger
.
info
(
f
"After extract feat, the cached remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
logger
.
info
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
info
(
f
"global frames:
{
self
.
num_frames
}
"
)
def
decode
(
self
,
is_finished
=
False
):
"""advance decoding
Args:
is_finished (bool, optional): Is last frame or not. Defaults to False.
Returns:
None:
"""
if
"deepspeech2"
in
self
.
model_type
:
decoding_chunk_size
=
1
# decoding chunk size = 1. int decoding frame unit
context
=
7
# context=7, in audio frame unit
subsampling
=
4
# subsampling=4, in audio frame unit
cached_feature_num
=
context
-
subsampling
# decoding window for model, in audio frame unit
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
# decoding stride for model, in audio frame unit
stride
=
subsampling
*
decoding_chunk_size
if
self
.
cached_feat
is
None
:
logger
.
info
(
"no audio feat, please input more pcm data"
)
return
num_frames
=
self
.
cached_feat
.
shape
[
1
]
logger
.
info
(
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
)
# the cached feat must be larger decoding_window
if
num_frames
<
decoding_window
and
not
is_finished
:
logger
.
info
(
f
"frame feat num is less than
{
decoding_window
}
, please input more pcm data"
)
return
None
,
None
# if is_finished=True, we need at least context frames
if
num_frames
<
context
:
logger
.
info
(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return
None
,
None
logger
.
info
(
"start to do model forward"
)
# num_frames - context + 1 ensure that current frame can get context window
if
is_finished
:
# if get the finished chunk, we need process the last context
left_frames
=
context
else
:
# we only process decoding_window frames for one chunk
left_frames
=
decoding_window
end
=
None
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
end
=
min
(
cur
+
decoding_window
,
num_frames
)
# extract the audio
x_chunk
=
self
.
cached_feat
[:,
cur
:
end
,
:].
numpy
()
x_chunk_lens
=
np
.
array
([
x_chunk
.
shape
[
1
]])
trans_best
=
self
.
decode_one_chunk
(
x_chunk
,
x_chunk_lens
)
self
.
result_transcripts
=
[
trans_best
]
# update feat cache
self
.
cached_feat
=
self
.
cached_feat
[:,
end
-
cached_feature_num
:,
:]
# return trans_best[0]
else
:
raise
Exception
(
f
"
{
self
.
model_type
}
not support paddleinference."
)
@
paddle
.
no_grad
()
def
decode_one_chunk
(
self
,
x_chunk
,
x_chunk_lens
):
"""forward one chunk frames
Args:
x_chunk (np.ndarray): (B,T,D), audio frames.
x_chunk_lens ([type]): (B,), audio frame lens
Returns:
logprob: poster probability.
"""
logger
.
info
(
"start to decoce one chunk for deepspeech2"
)
# state_c, state_h, audio_lens, audio
# 'chunk_state_c_box', 'chunk_state_h_box', 'audio_chunk_lens', 'audio_chunk'
input_names
=
[
n
.
name
for
n
in
self
.
am_predictor
.
get_inputs
()]
logger
.
info
(
f
"ort inputs:
{
input_names
}
"
)
# 'softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'
# audio, audio_lens, state_h, state_c
output_names
=
[
n
.
name
for
n
in
self
.
am_predictor
.
get_outputs
()]
logger
.
info
(
f
"ort outpus:
{
output_names
}
"
)
assert
(
len
(
input_names
)
==
len
(
output_names
))
assert
isinstance
(
input_names
[
0
],
str
)
input_datas
=
[
self
.
chunk_state_c_box
,
self
.
chunk_state_h_box
,
x_chunk_lens
,
x_chunk
]
feeds
=
dict
(
zip
(
input_names
,
input_datas
))
outputs
=
self
.
am_predictor
.
run
(
[
*
output_names
],
{
**
feeds
})
output_chunk_probs
,
output_chunk_lens
,
self
.
chunk_state_h_box
,
self
.
chunk_state_c_box
=
outputs
self
.
decoder
.
next
(
output_chunk_probs
,
output_chunk_lens
)
trans_best
,
trans_beam
=
self
.
decoder
.
decode
()
logger
.
info
(
f
"decode one best result for deepspeech2:
{
trans_best
[
0
]
}
"
)
return
trans_best
[
0
]
def
get_result
(
self
):
"""return partial/ending asr result.
Returns:
str: one best result of partial/ending.
"""
if
len
(
self
.
result_transcripts
)
>
0
:
return
self
.
result_transcripts
[
0
]
else
:
return
''
class
ASRServerExecutor
(
ASRExecutor
):
def
__init__
(
self
):
super
().
__init__
()
self
.
task_resource
=
CommonTaskResource
(
task
=
'asr'
,
model_format
=
'static'
,
inference_mode
=
'online'
)
def
update_config
(
self
)
->
None
:
if
"deepspeech2"
in
self
.
model_type
:
with
UpdateConfig
(
self
.
config
):
# download lm
self
.
config
.
decode
.
lang_model_path
=
os
.
path
.
join
(
MODEL_HOME
,
'language_model'
,
self
.
config
.
decode
.
lang_model_path
)
lm_url
=
self
.
task_resource
.
res_dict
[
'lm_url'
]
lm_md5
=
self
.
task_resource
.
res_dict
[
'lm_md5'
]
logger
.
info
(
f
"Start to load language model
{
lm_url
}
"
)
self
.
download_lm
(
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
else
:
raise
NotImplementedError
(
f
"
{
self
.
model_type
}
not support paddleinference."
)
def
init_model
(
self
)
->
None
:
if
"deepspeech2"
in
self
.
model_type
:
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor
=
onnx_infer
.
get_sess
(
model_path
=
self
.
am_model
,
sess_conf
=
self
.
am_predictor_conf
)
else
:
raise
NotImplementedError
(
f
"
{
self
.
model_type
}
not support paddleinference."
)
def
_init_from_path
(
self
,
model_type
:
str
=
None
,
am_model
:
Optional
[
os
.
PathLike
]
=
None
,
am_params
:
Optional
[
os
.
PathLike
]
=
None
,
lang
:
str
=
'zh'
,
sample_rate
:
int
=
16000
,
cfg_path
:
Optional
[
os
.
PathLike
]
=
None
,
decode_method
:
str
=
'attention_rescoring'
,
num_decoding_left_chunks
:
int
=-
1
,
am_predictor_conf
:
dict
=
None
):
"""
Init model and other resources from a specific path.
"""
if
not
model_type
or
not
lang
or
not
sample_rate
:
logger
.
error
(
"The model type or lang or sample rate is None, please input an valid server parameter yaml"
)
return
False
assert
am_params
is
None
,
"am_params not used in onnx engine"
self
.
model_type
=
model_type
self
.
sample_rate
=
sample_rate
self
.
decode_method
=
decode_method
self
.
num_decoding_left_chunks
=
num_decoding_left_chunks
# conf for paddleinference predictor or onnx
self
.
am_predictor_conf
=
am_predictor_conf
logger
.
info
(
f
"model_type:
{
self
.
model_type
}
"
)
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
self
.
task_resource
.
set_task_model
(
model_tag
=
tag
)
if
cfg_path
is
None
:
self
.
res_path
=
self
.
task_resource
.
res_dir
self
.
cfg_path
=
os
.
path
.
join
(
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'cfg_path'
])
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
self
.
am_model
=
os
.
path
.
join
(
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'model'
])
if
am_model
is
None
else
os
.
path
.
abspath
(
am_model
)
self
.
am_params
=
os
.
path
.
join
(
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'params'
])
if
am_params
is
None
else
os
.
path
.
abspath
(
am_params
)
logger
.
info
(
"Load the pretrained model:"
)
logger
.
info
(
f
" tag =
{
tag
}
"
)
logger
.
info
(
f
" res_path:
{
self
.
res_path
}
"
)
logger
.
info
(
f
" cfg path:
{
self
.
cfg_path
}
"
)
logger
.
info
(
f
" am_model path:
{
self
.
am_model
}
"
)
logger
.
info
(
f
" am_params path:
{
self
.
am_params
}
"
)
#Init body.
self
.
config
=
CfgNode
(
new_allowed
=
True
)
self
.
config
.
merge_from_file
(
self
.
cfg_path
)
if
self
.
config
.
spm_model_prefix
:
self
.
config
.
spm_model_prefix
=
os
.
path
.
join
(
self
.
res_path
,
self
.
config
.
spm_model_prefix
)
logger
.
info
(
f
"spm model path:
{
self
.
config
.
spm_model_prefix
}
"
)
self
.
vocab
=
self
.
config
.
vocab_filepath
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
unit_type
,
vocab
=
self
.
config
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
spm_model_prefix
)
self
.
update_config
()
# AM predictor
self
.
init_model
()
logger
.
info
(
f
"create the
{
model_type
}
model success"
)
return
True
class
ASREngine
(
BaseEngine
):
"""ASR model resource
Args:
metaclass: Defaults to Singleton.
"""
def
__init__
(
self
):
super
(
ASREngine
,
self
).
__init__
()
def
init_model
(
self
)
->
bool
:
if
not
self
.
executor
.
_init_from_path
(
model_type
=
self
.
config
.
model_type
,
am_model
=
self
.
config
.
am_model
,
am_params
=
self
.
config
.
am_params
,
lang
=
self
.
config
.
lang
,
sample_rate
=
self
.
config
.
sample_rate
,
cfg_path
=
self
.
config
.
cfg_path
,
decode_method
=
self
.
config
.
decode_method
,
num_decoding_left_chunks
=
self
.
config
.
num_decoding_left_chunks
,
am_predictor_conf
=
self
.
config
.
am_predictor_conf
):
return
False
return
True
def
init
(
self
,
config
:
dict
)
->
bool
:
"""init engine resource
Args:
config_file (str): config file
Returns:
bool: init failed or success
"""
self
.
config
=
config
self
.
executor
=
ASRServerExecutor
()
try
:
self
.
device
=
self
.
config
.
get
(
"device"
,
paddle
.
get_device
())
paddle
.
set_device
(
self
.
device
)
except
BaseException
as
e
:
logger
.
error
(
f
"Set device failed, please check if device '
{
self
.
device
}
' is already used and the parameter 'device' in the yaml file"
)
logger
.
error
(
"If all GPU or XPU is used, you can set the server to 'cpu'"
)
sys
.
exit
(
-
1
)
logger
.
info
(
f
"paddlespeech_server set the device:
{
self
.
device
}
"
)
if
not
self
.
init_model
():
logger
.
error
(
"Init the ASR server occurs error, please check the server configuration yaml"
)
return
False
logger
.
info
(
"Initialize ASR server engine successfully."
)
return
True
def
new_handler
(
self
):
"""New handler from model.
Returns:
PaddleASRConnectionHanddler: asr handler instance
"""
return
PaddleASRConnectionHanddler
(
self
)
def
preprocess
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
(
"Online not using this."
)
def
run
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
(
"Online not using this."
)
def
postprocess
(
self
):
raise
NotImplementedError
(
"Online not using this."
)
paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py
浏览文件 @
3cee7db0
...
...
@@ -471,7 +471,6 @@ class ASREngine(BaseEngine):
def
__init__
(
self
):
super
(
ASREngine
,
self
).
__init__
()
logger
.
info
(
"create the online asr engine resource instance"
)
def
init_model
(
self
)
->
bool
:
if
not
self
.
executor
.
_init_from_path
(
...
...
paddlespeech/server/engine/asr/online/python/asr_engine.py
浏览文件 @
3cee7db0
...
...
@@ -845,7 +845,6 @@ class ASREngine(BaseEngine):
def
__init__
(
self
):
super
(
ASREngine
,
self
).
__init__
()
logger
.
info
(
"create the online asr engine resource instance"
)
def
init_model
(
self
)
->
bool
:
if
not
self
.
executor
.
_init_from_path
(
...
...
paddlespeech/server/engine/engine_factory.py
浏览文件 @
3cee7db0
...
...
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Text
from
..utils.log
import
logger
__all__
=
[
'EngineFactory'
]
class
EngineFactory
(
object
):
@
staticmethod
def
get_engine
(
engine_name
:
Text
,
engine_type
:
Text
):
logger
.
info
(
f
"
{
engine_name
}
:
{
engine_type
}
engine."
)
if
engine_name
==
'asr'
and
engine_type
==
'inference'
:
from
paddlespeech.server.engine.asr.paddleinference.asr_engine
import
ASREngine
return
ASREngine
()
...
...
paddlespeech/server/utils/onnx_infer.py
浏览文件 @
3cee7db0
...
...
@@ -16,21 +16,33 @@ from typing import Optional
import
onnxruntime
as
ort
from
.log
import
logger
def
get_sess
(
model_path
:
Optional
[
os
.
PathLike
]
=
None
,
sess_conf
:
dict
=
None
):
logger
.
info
(
f
"ort sessconf:
{
sess_conf
}
"
)
sess_options
=
ort
.
SessionOptions
()
sess_options
.
graph_optimization_level
=
ort
.
GraphOptimizationLevel
.
ORT_ENABLE_ALL
if
sess_conf
.
get
(
'graph_optimization_level'
,
99
)
==
0
:
sess_options
.
graph_optimization_level
=
ort
.
GraphOptimizationLevel
.
ORT_DISABLE_ALL
sess_options
.
execution_mode
=
ort
.
ExecutionMode
.
ORT_SEQUENTIAL
if
"gpu"
in
sess_conf
[
"device"
]:
# "gpu:0"
providers
=
[
'CPUExecutionProvider'
]
if
"gpu"
in
sess_conf
.
get
(
"device"
,
""
):
providers
=
[
'CUDAExecutionProvider'
]
# fastspeech2/mb_melgan can't use trt now!
if
sess_conf
[
"use_trt"
]
:
if
sess_conf
.
get
(
"use_trt"
,
0
)
:
providers
=
[
'TensorrtExecutionProvider'
]
logger
.
info
(
f
"ort providers:
{
providers
}
"
)
if
'cpu_threads'
in
sess_conf
:
sess_options
.
intra_op_num_threads
=
sess_conf
.
get
(
"cpu_threads"
,
0
)
else
:
providers
=
[
'CUDAExecutionProvider'
]
elif
sess_conf
[
"device"
]
==
"cpu"
:
providers
=
[
'CPUExecutionProvider'
]
sess_options
.
intra_op_num_threads
=
sess_conf
[
"cpu_threads"
]
sess_options
.
intra_op_num_threads
=
sess_conf
.
get
(
"intra_op_num_threads"
,
0
)
sess_options
.
inter_op_num_threads
=
sess_conf
.
get
(
"inter_op_num_threads"
,
0
)
sess
=
ort
.
InferenceSession
(
model_path
,
providers
=
providers
,
sess_options
=
sess_options
)
return
sess
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录