Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
9106daa2
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9106daa2
编写于
6月 16, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
code format
上级
42d28b96
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
25 addition
and
17 deletion
+25
-17
demos/streaming_asr_server/conf/ws_ds2_application.yaml
demos/streaming_asr_server/conf/ws_ds2_application.yaml
+1
-1
paddlespeech/resource/resource.py
paddlespeech/resource/resource.py
+3
-1
paddlespeech/server/engine/asr/online/onnx/asr_engine.py
paddlespeech/server/engine/asr/online/onnx/asr_engine.py
+11
-9
paddlespeech/server/engine/engine_factory.py
paddlespeech/server/engine/engine_factory.py
+3
-0
paddlespeech/server/utils/onnx_infer.py
paddlespeech/server/utils/onnx_infer.py
+5
-4
speechx/examples/ds2_ol/onnx/local/infer_check.py
speechx/examples/ds2_ol/onnx/local/infer_check.py
+2
-2
未找到文件。
demos/streaming_asr_server/conf/ws_ds2_application.yaml
浏览文件 @
9106daa2
...
@@ -11,7 +11,7 @@ port: 8090
...
@@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket'] (only one can be selected).
# protocol = ['websocket'] (only one can be selected).
# websocket only support online engine type.
# websocket only support online engine type.
protocol
:
'
websocket'
protocol
:
'
websocket'
engine_list
:
[
'
asr_online-
onnx
'
]
engine_list
:
[
'
asr_online-
inference
'
]
#################################################################################
#################################################################################
...
...
paddlespeech/resource/resource.py
浏览文件 @
9106daa2
...
@@ -164,9 +164,11 @@ class CommonTaskResource:
...
@@ -164,9 +164,11 @@ class CommonTaskResource:
try
:
try
:
import_models
=
'{}_{}_pretrained_models'
.
format
(
self
.
task
,
import_models
=
'{}_{}_pretrained_models'
.
format
(
self
.
task
,
self
.
model_format
)
self
.
model_format
)
print
(
f
"from .pretrained_models import
{
import_models
}
"
)
exec
(
'from .pretrained_models import {}'
.
format
(
import_models
))
exec
(
'from .pretrained_models import {}'
.
format
(
import_models
))
models
=
OrderedDict
(
locals
()[
import_models
])
models
=
OrderedDict
(
locals
()[
import_models
])
except
ImportError
:
except
Exception
as
e
:
print
(
e
)
models
=
OrderedDict
({})
# no models.
models
=
OrderedDict
({})
# no models.
finally
:
finally
:
return
models
return
models
...
...
paddlespeech/server/engine/asr/online/onnx/asr_engine.py
浏览文件 @
9106daa2
...
@@ -306,12 +306,13 @@ class PaddleASRConnectionHanddler:
...
@@ -306,12 +306,13 @@ class PaddleASRConnectionHanddler:
assert
(
len
(
input_names
)
==
len
(
output_names
))
assert
(
len
(
input_names
)
==
len
(
output_names
))
assert
isinstance
(
input_names
[
0
],
str
)
assert
isinstance
(
input_names
[
0
],
str
)
input_datas
=
[
self
.
chunk_state_c_box
,
self
.
chunk_state_h_box
,
x_chunk_lens
,
x_chunk
]
input_datas
=
[
self
.
chunk_state_c_box
,
self
.
chunk_state_h_box
,
x_chunk_lens
,
x_chunk
]
feeds
=
dict
(
zip
(
input_names
,
input_datas
))
feeds
=
dict
(
zip
(
input_names
,
input_datas
))
outputs
=
self
.
am_predictor
.
run
(
outputs
=
self
.
am_predictor
.
run
([
*
output_names
],
{
**
feeds
})
[
*
output_names
],
{
**
feeds
})
output_chunk_probs
,
output_chunk_lens
,
self
.
chunk_state_h_box
,
self
.
chunk_state_c_box
=
outputs
output_chunk_probs
,
output_chunk_lens
,
self
.
chunk_state_h_box
,
self
.
chunk_state_c_box
=
outputs
self
.
decoder
.
next
(
output_chunk_probs
,
output_chunk_lens
)
self
.
decoder
.
next
(
output_chunk_probs
,
output_chunk_lens
)
...
@@ -335,7 +336,7 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -335,7 +336,7 @@ class ASRServerExecutor(ASRExecutor):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
self
.
task_resource
=
CommonTaskResource
(
self
.
task_resource
=
CommonTaskResource
(
task
=
'asr'
,
model_format
=
'
static
'
,
inference_mode
=
'online'
)
task
=
'asr'
,
model_format
=
'
onnx
'
,
inference_mode
=
'online'
)
def
update_config
(
self
)
->
None
:
def
update_config
(
self
)
->
None
:
if
"deepspeech2"
in
self
.
model_type
:
if
"deepspeech2"
in
self
.
model_type
:
...
@@ -407,10 +408,11 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -407,10 +408,11 @@ class ASRServerExecutor(ASRExecutor):
self
.
res_path
=
os
.
path
.
dirname
(
self
.
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
self
.
am_model
=
os
.
path
.
join
(
self
.
res_path
,
self
.
am_model
=
os
.
path
.
join
(
self
.
res_path
,
self
.
task_resource
.
res_dict
[
self
.
task_resource
.
res_dict
[
'model'
])
if
am_model
is
None
else
os
.
path
.
abspath
(
am_model
)
'model'
])
if
am_model
is
None
else
os
.
path
.
abspath
(
am_model
)
self
.
am_params
=
os
.
path
.
join
(
self
.
res_path
,
self
.
am_params
=
os
.
path
.
join
(
self
.
task_resource
.
res_dict
[
'params'
])
if
am_params
is
None
else
os
.
path
.
abspath
(
am_params
)
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'params'
])
if
am_params
is
None
else
os
.
path
.
abspath
(
am_params
)
logger
.
info
(
"Load the pretrained model:"
)
logger
.
info
(
"Load the pretrained model:"
)
logger
.
info
(
f
" tag =
{
tag
}
"
)
logger
.
info
(
f
" tag =
{
tag
}
"
)
...
...
paddlespeech/server/engine/engine_factory.py
浏览文件 @
9106daa2
...
@@ -12,14 +12,17 @@
...
@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
Text
from
typing
import
Text
from
..utils.log
import
logger
from
..utils.log
import
logger
__all__
=
[
'EngineFactory'
]
__all__
=
[
'EngineFactory'
]
class
EngineFactory
(
object
):
class
EngineFactory
(
object
):
@
staticmethod
@
staticmethod
def
get_engine
(
engine_name
:
Text
,
engine_type
:
Text
):
def
get_engine
(
engine_name
:
Text
,
engine_type
:
Text
):
logger
.
info
(
f
"
{
engine_name
}
:
{
engine_type
}
engine."
)
logger
.
info
(
f
"
{
engine_name
}
:
{
engine_type
}
engine."
)
if
engine_name
==
'asr'
and
engine_type
==
'inference'
:
if
engine_name
==
'asr'
and
engine_type
==
'inference'
:
from
paddlespeech.server.engine.asr.paddleinference.asr_engine
import
ASREngine
from
paddlespeech.server.engine.asr.paddleinference.asr_engine
import
ASREngine
return
ASREngine
()
return
ASREngine
()
...
...
paddlespeech/server/utils/onnx_infer.py
浏览文件 @
9106daa2
...
@@ -35,14 +35,15 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
...
@@ -35,14 +35,15 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
if
sess_conf
.
get
(
"use_trt"
,
0
):
if
sess_conf
.
get
(
"use_trt"
,
0
):
providers
=
[
'TensorrtExecutionProvider'
]
providers
=
[
'TensorrtExecutionProvider'
]
logger
.
info
(
f
"ort providers:
{
providers
}
"
)
logger
.
info
(
f
"ort providers:
{
providers
}
"
)
if
'cpu_threads'
in
sess_conf
:
if
'cpu_threads'
in
sess_conf
:
sess_options
.
intra_op_num_threads
=
sess_conf
.
get
(
"cpu_threads"
,
0
)
sess_options
.
intra_op_num_threads
=
sess_conf
.
get
(
"cpu_threads"
,
0
)
else
:
else
:
sess_options
.
intra_op_num_threads
=
sess_conf
.
get
(
"intra_op_num_threads"
,
0
)
sess_options
.
intra_op_num_threads
=
sess_conf
.
get
(
"intra_op_num_threads"
,
0
)
sess_options
.
inter_op_num_threads
=
sess_conf
.
get
(
"inter_op_num_threads"
,
0
)
sess_options
.
inter_op_num_threads
=
sess_conf
.
get
(
"inter_op_num_threads"
,
0
)
sess
=
ort
.
InferenceSession
(
sess
=
ort
.
InferenceSession
(
model_path
,
providers
=
providers
,
sess_options
=
sess_options
)
model_path
,
providers
=
providers
,
sess_options
=
sess_options
)
return
sess
return
sess
speechx/examples/ds2_ol/onnx/local/infer_check.py
浏览文件 @
9106daa2
...
@@ -27,7 +27,8 @@ def parse_args():
...
@@ -27,7 +27,8 @@ def parse_args():
'--input_file'
,
'--input_file'
,
type
=
str
,
type
=
str
,
default
=
"static_ds2online_inputs.pickle"
,
default
=
"static_ds2online_inputs.pickle"
,
help
=
"aishell ds2 input data file. For wenetspeech, we only feed for infer model"
,
)
help
=
"aishell ds2 input data file. For wenetspeech, we only feed for infer model"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
'--model_type'
,
'--model_type'
,
type
=
str
,
type
=
str
,
...
@@ -57,7 +58,6 @@ if __name__ == '__main__':
...
@@ -57,7 +58,6 @@ if __name__ == '__main__':
iodict
=
pickle
.
load
(
f
)
iodict
=
pickle
.
load
(
f
)
print
(
iodict
.
keys
())
print
(
iodict
.
keys
())
audio_chunk
=
iodict
[
'audio_chunk'
]
audio_chunk
=
iodict
[
'audio_chunk'
]
audio_chunk_lens
=
iodict
[
'audio_chunk_lens'
]
audio_chunk_lens
=
iodict
[
'audio_chunk_lens'
]
chunk_state_h_box
=
iodict
[
'chunk_state_h_box'
]
chunk_state_h_box
=
iodict
[
'chunk_state_h_box'
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录