PaddlePaddle / DeepSpeech
Commit 14febbe5
Authored on Feb 07, 2022 by lym0302

add paddle inference code, test=doc

Parent: 5f3193e9
Showing 4 changed files with 596 additions and 6 deletions (+596, -6)
speechserving/speechserving/conf/tts/tts_pd.yaml (+45, -0)
speechserving/speechserving/engine/tts/paddleinference/tts_engine.py (+463, -0)
speechserving/speechserving/utils/audio_process.py (+6, -6)
speechserving/speechserving/utils/paddle_predictor.py (+82, -0)
speechserving/speechserving/conf/tts/tts_pd.yaml (new file, mode 100644)
# This is the parameter configuration file for the TTS server.
# These are the static models that support paddle inference.

##################################################################
#                       TTS SERVER SETTING                       #
##################################################################
host: '0.0.0.0'
port: 8692

##################################################################
#                     ACOUSTIC MODEL SETTING                     #
#    am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc']      #
##################################################################
am: 'fastspeech2_csmsc'
am_model:
am_params:
phones_dict: './dict_dir/phone_id_map.txt'
tones_dict:
speaker_dict:
spk_id: 0

am_predictor_conf:
    use_gpu: 'true'
    enable_mkldnn: 'true'
    switch_ir_optim: 'true'

##################################################################
#                        VOCODER SETTING                         #
# voc choices=['pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc']#
##################################################################
voc: 'pwgan_csmsc'
voc_model:
voc_params:

voc_predictor_conf:
    use_gpu: 'true'
    enable_mkldnn: 'true'
    switch_ir_optim: 'true'

##################################################################
#                            OTHERS                              #
##################################################################
lang: 'zh'
device: paddle.get_device()
\ No newline at end of file
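The TTSEngine below loads this file with yaml.safe_load and forwards each key to _init_from_path, so a config can be sanity-checked the same way. A minimal sketch, assuming the file sits at its default path ./conf/tts/tts_pd.yaml:

import yaml

with open("./conf/tts/tts_pd.yaml", "rt") as f:
    conf = yaml.safe_load(f)

# the choices this commit supports, per the comments in the file above
assert conf["am"] in {"speedyspeech_csmsc", "fastspeech2_csmsc"}
assert conf["voc"] in {"pwgan_csmsc", "mb_melgan_csmsc", "hifigan_csmsc"}
print(conf["host"], conf["port"], conf["am_predictor_conf"]["use_gpu"])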
speechserving/speechserving/engine/tts/paddleinference/tts_engine.py (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import base64
import io
import os
from typing import Optional

import librosa
import numpy as np
import paddle
import soundfile as sf
import yaml
from engine.base_engine import BaseEngine
from scipy.io import wavfile

from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from utils.audio_process import change_speed
from utils.errors import ErrorCode
from utils.exception import ServerBaseException
from utils.paddle_predictor import init_predictor
from utils.paddle_predictor import run_model

#from paddle.inference import Config
#from paddle.inference import create_predictor

__all__ = ['TTSEngine']
# Static models applied on paddle inference
pretrained_models = {
    # speedyspeech
    "speedyspeech_csmsc-zh": {
        'url': 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip',
        'md5': '9a849a74d1be0c758dd5a1b9c8f77f3d',
        'model': 'speedyspeech_csmsc.pdmodel',
        'params': 'speedyspeech_csmsc.pdiparams',
        'phones_dict': 'phone_id_map.txt',
        'tones_dict': 'tone_id_map.txt',
    },
    # fastspeech2
    "fastspeech2_csmsc-zh": {
        'url': 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip',
        'md5': '8eb01c2e4bc7e8b59beaa9fa046069cf',
        'model': 'fastspeech2_csmsc.pdmodel',
        'params': 'fastspeech2_csmsc.pdiparams',
        'phones_dict': 'phone_id_map.txt',
    },
    # pwgan
    "pwgan_csmsc-zh": {
        'url': 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip',
        'md5': 'e3504aed9c5a290be12d1347836d2742',
        'model': 'pwgan_csmsc.pdmodel',
        'params': 'pwgan_csmsc.pdiparams',
    },
    # mb_melgan
    "mb_melgan_csmsc-zh": {
        'url': 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip',
        'md5': 'ac6eee94ba483421d750433f4c3b8d36',
        'model': 'mb_melgan_csmsc.pdmodel',
        'params': 'mb_melgan_csmsc.pdiparams',
    },
    # hifigan
    "hifigan_csmsc-zh": {
        'url': 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip',
        'md5': '7edd8c436b3a5546b3a7cb8cff9d5a0c',
        'model': 'hifigan_csmsc.pdmodel',
        'params': 'hifigan_csmsc.pdiparams',
    },
}
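# Editor's sketch, not part of this commit: verifying a downloaded archive
# against the md5 entries above. download_and_decompress from
# paddlespeech.cli.utils presumably performs this check internally; the
# md5_of helper below is hypothetical.
def md5_of(path: str) -> str:
    import hashlib
    h = hashlib.md5()
    with open(path, "rb") as f:
        # hash in chunks so large archives do not fill memory
        for chunk in iter(lambda: f.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()
# e.g. after fetching the pwgan archive listed above:
# assert md5_of("pwg_baker_static_0.4.zip") == pretrained_models["pwgan_csmsc-zh"]["md5"]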
class TTSServerExecutor(TTSExecutor):
    def __init__(self):
        super().__init__()
        self.parser = argparse.ArgumentParser(
            prog='paddlespeech.tts', add_help=True)
        self.parser.add_argument(
            '--conf',
            type=str,
            default='./conf/tts/tts_pd.yaml',
            help='Configuration parameters.')
    def _get_pretrained_path(self, tag: str) -> os.PathLike:
        """
        Download and return the pretrained resources path of the current task.
        """
        assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format(
            tag)

        res_path = os.path.join(MODEL_HOME, tag)
        decompressed_path = download_and_decompress(pretrained_models[tag],
                                                    res_path)
        decompressed_path = os.path.abspath(decompressed_path)
        logger.info(
            'Use pretrained model stored in: {}'.format(decompressed_path))
        return decompressed_path
    def _init_from_path(
            self,
            am: str='fastspeech2_csmsc',
            am_model: Optional[os.PathLike]=None,
            am_params: Optional[os.PathLike]=None,
            phones_dict: Optional[os.PathLike]=None,
            tones_dict: Optional[os.PathLike]=None,
            speaker_dict: Optional[os.PathLike]=None,
            voc: str='pwgan_csmsc',
            voc_model: Optional[os.PathLike]=None,
            voc_params: Optional[os.PathLike]=None,
            lang: str='zh',
            am_predictor_conf: dict=None,
            voc_predictor_conf: dict=None, ):
        """
        Init model and other resources from a specific path.
        """
        if hasattr(self, 'am') and hasattr(self, 'voc'):
            logger.info('Models had been initialized.')
            return

        # am
        am_tag = am + '-' + lang
        if phones_dict is None:
            print("please input phones_dict!")
        # Once the downloaded models ship with phone and tone dicts, this will no longer be needed.
        #if am_model is None or am_params is None or phones_dict is None:
        if am_model is None or am_params is None:
            am_res_path = self._get_pretrained_path(am_tag)
            self.am_res_path = am_res_path
            self.am_model = os.path.join(am_res_path,
                                         pretrained_models[am_tag]['model'])
            self.am_params = os.path.join(am_res_path,
                                          pretrained_models[am_tag]['params'])
            # must have phones_dict in acoustic
            #self.phones_dict = os.path.join(
            #    am_res_path, pretrained_models[am_tag]['phones_dict'])
            self.phones_dict = os.path.abspath(phones_dict)

            logger.info(am_res_path)
            logger.info(self.am_model)
            logger.info(self.am_params)
        else:
            self.am_model = os.path.abspath(am_model)
            self.am_params = os.path.abspath(am_params)
            self.phones_dict = os.path.abspath(phones_dict)
            self.am_res_path = os.path.dirname(os.path.abspath(self.am_model))
        print("self.phones_dict:", self.phones_dict)

        # for speedyspeech
        self.tones_dict = None
        if 'tones_dict' in pretrained_models[am_tag]:
            self.tones_dict = os.path.join(
                am_res_path, pretrained_models[am_tag]['tones_dict'])
            if tones_dict:
                self.tones_dict = tones_dict

        # for multi speaker fastspeech2
        self.speaker_dict = None
        if 'speaker_dict' in pretrained_models[am_tag]:
            self.speaker_dict = os.path.join(
                am_res_path, pretrained_models[am_tag]['speaker_dict'])
            if speaker_dict:
                self.speaker_dict = speaker_dict

        # voc
        voc_tag = voc + '-' + lang
        if voc_model is None or voc_params is None:
            voc_res_path = self._get_pretrained_path(voc_tag)
            self.voc_res_path = voc_res_path
            self.voc_model = os.path.join(voc_res_path,
                                          pretrained_models[voc_tag]['model'])
            self.voc_params = os.path.join(
                voc_res_path, pretrained_models[voc_tag]['params'])
            logger.info(voc_res_path)
            logger.info(self.voc_model)
            logger.info(self.voc_params)
        else:
            self.voc_model = os.path.abspath(voc_model)
            self.voc_params = os.path.abspath(voc_params)
            self.voc_res_path = os.path.dirname(
                os.path.abspath(self.voc_model))

        # Init body.
        with open(self.phones_dict, "r") as f:
            phn_id = [line.strip().split() for line in f.readlines()]
        vocab_size = len(phn_id)
        print("vocab_size:", vocab_size)

        tone_size = None
        if self.tones_dict:
            with open(self.tones_dict, "r") as f:
                tone_id = [line.strip().split() for line in f.readlines()]
            tone_size = len(tone_id)
            print("tone_size:", tone_size)

        spk_num = None
        if self.speaker_dict:
            with open(self.speaker_dict, 'rt') as f:
                spk_id = [line.strip().split() for line in f.readlines()]
            spk_num = len(spk_id)
            print("spk_num:", spk_num)

        # frontend
        if lang == 'zh':
            self.frontend = Frontend(
                phone_vocab_path=self.phones_dict,
                tone_vocab_path=self.tones_dict)
        elif lang == 'en':
            self.frontend = English(phone_vocab_path=self.phones_dict)
        print("frontend done!")

        # am predictor
        self.am_predictor_conf = am_predictor_conf
        self.am_predictor = init_predictor(
            model_file=self.am_model,
            params_file=self.am_params,
            predictor_conf=self.am_predictor_conf)

        # voc predictor
        self.voc_predictor_conf = voc_predictor_conf
        self.voc_predictor = init_predictor(
            model_file=self.voc_model,
            params_file=self.voc_params,
            predictor_conf=self.voc_predictor_conf)
    @paddle.no_grad()
    def infer(self,
              text: str,
              lang: str='zh',
              am: str='fastspeech2_csmsc',
              spk_id: int=0):
        """
        Model inference and result stored in self.output.
        """
        am_name = am[:am.rindex('_')]
        am_dataset = am[am.rindex('_') + 1:]
        get_tone_ids = False
        merge_sentences = False
        if am_name == 'speedyspeech':
            get_tone_ids = True
        if lang == 'zh':
            input_ids = self.frontend.get_input_ids(
                text,
                merge_sentences=merge_sentences,
                get_tone_ids=get_tone_ids)
            phone_ids = input_ids["phone_ids"]
            if get_tone_ids:
                tone_ids = input_ids["tone_ids"]
        elif lang == 'en':
            input_ids = self.frontend.get_input_ids(
                text, merge_sentences=merge_sentences)
            phone_ids = input_ids["phone_ids"]
        else:
            print("lang should be in {'zh', 'en'}!")

        flags = 0
        for i in range(len(phone_ids)):
            part_phone_ids = phone_ids[i]
            # am
            if am_name == 'speedyspeech':
                part_tone_ids = tone_ids[i]
                am_result = run_model(
                    self.am_predictor,
                    [part_phone_ids.numpy(), part_tone_ids.numpy()])
                mel = am_result[0]
            # fastspeech2
            else:
                # multi speaker do not have static model
                if am_dataset in {"aishell3", "vctk"}:
                    pass
                else:
                    am_result = run_model(self.am_predictor,
                                          [part_phone_ids.numpy()])
                    mel = am_result[0]
            # voc
            voc_result = run_model(self.voc_predictor, [mel])
            wav = voc_result[0]
            wav = paddle.to_tensor(wav)

            if flags == 0:
                wav_all = wav
                flags = 1
            else:
                wav_all = paddle.concat([wav_all, wav])
        self._outputs['wav'] = wav_all
class TTSEngine(BaseEngine):
    """TTS server engine

    Args:
        metaclass: Defaults to Singleton.
    """

    def __init__(self, name=None):
        """Initialize TTS server engine
        """
        super(TTSEngine, self).__init__()
        self.executor = TTSServerExecutor()

        config_path = self.executor.parser.parse_args().conf
        with open(config_path, 'rt') as f:
            self.conf_dict = yaml.safe_load(f)

        self.executor._init_from_path(
            am=self.conf_dict["am"],
            am_model=self.conf_dict["am_model"],
            am_params=self.conf_dict["am_params"],
            phones_dict=self.conf_dict["phones_dict"],
            tones_dict=self.conf_dict["tones_dict"],
            speaker_dict=self.conf_dict["speaker_dict"],
            voc=self.conf_dict["voc"],
            voc_model=self.conf_dict["voc_model"],
            voc_params=self.conf_dict["voc_params"],
            lang=self.conf_dict["lang"],
            am_predictor_conf=self.conf_dict["am_predictor_conf"],
            voc_predictor_conf=self.conf_dict["voc_predictor_conf"], )

        logger.info("Initialize TTS server engine successfully.")
    def postprocess(self,
                    wav,
                    original_fs: int,
                    target_fs: int=16000,
                    volume: float=1.0,
                    speed: float=1.0,
                    audio_path: str=None):
        """Post-processing operations, including speed, volume, and sample-rate
        conversion, and optionally saving the audio file.

        Args:
            wav (numpy(float)): Synthesized audio sample points
            original_fs (int): original audio sample rate
            target_fs (int): target audio sample rate
            volume (float): target volume
            speed (float): target speed
        """
        # transform sample_rate
        if target_fs == 0 or target_fs > original_fs:
            target_fs = original_fs
            wav_tar_fs = wav
        else:
            wav_tar_fs = librosa.resample(
                np.squeeze(wav), original_fs, target_fs)

        # transform volume
        wav_vol = wav_tar_fs * volume

        # transform speed
        try:
            # windows does not support soxbindings
            wav_speed = change_speed(wav_vol, speed, target_fs)
        except Exception:
            raise ServerBaseException(
                ErrorCode.SERVER_INTERNAL_ERR,
                "Can not install soxbindings on your system.")

        # wav to base64
        buf = io.BytesIO()
        wavfile.write(buf, target_fs, wav_speed)
        buf.seek(0)  # rewind before reading; otherwise read() returns nothing
        base64_bytes = base64.b64encode(buf.read())
        wav_base64 = base64_bytes.decode('utf-8')

        # save audio
        if audio_path is not None and audio_path.endswith(".wav"):
            sf.write(audio_path, wav_speed, target_fs)
        elif audio_path is not None and audio_path.endswith(".pcm"):
            wav_norm = wav_speed * (32767 / max(0.001,
                                                np.max(np.abs(wav_speed))))
            with open(audio_path, "wb") as f:
                f.write(wav_norm.astype(np.int16))

        return target_fs, wav_base64
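    # Editor's sketch, not part of this commit: the wav_base64 produced above
    # wraps a complete WAV file, so a client can decode and read it with:
    #
    #     import base64, io
    #     from scipy.io import wavfile
    #     wav_bytes = base64.b64decode(wav_base64)
    #     sample_rate, samples = wavfile.read(io.BytesIO(wav_bytes))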
    def run(self,
            sentence: str,
            spk_id: int=0,
            speed: float=1.0,
            volume: float=1.0,
            sample_rate: int=0,
            save_path: str=None):
        """Get the result of the server response.

        Args:
            sentence (str): sentence to be synthesized
            spk_id (int, optional): speaker id. Defaults to 0.
            speed (float, optional): audio speed, 0 < speed <= 3.0. Defaults to 1.0.
            volume (float, optional): The volume relative to the audio synthesized by the model,
                0 < volume <= 3.0. Defaults to 1.0.
            sample_rate (int, optional): Set the sample rate of the synthesized audio.
                0 represents the sample rate of the model synthesis. Defaults to 0.
            save_path (str, optional): The save path of the synthesized audio. Defaults to None.

        Raises:
            ServerBaseException: when inference or post-processing fails

        Returns:
            lang, target_sample_rate, wav_base64
        """
        lang = self.conf_dict["lang"]

        try:
            self.executor.infer(
                text=sentence,
                lang=lang,
                am=self.conf_dict["am"],
                spk_id=spk_id)
        except Exception:
            raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                      "tts infer failed.")

        try:
            target_sample_rate, wav_base64 = self.postprocess(
                wav=self.executor._outputs['wav'].numpy(),
                #original_fs=self.executor.am_config.fs,
                original_fs=24000,  # TODO get sample rate from model
                target_fs=sample_rate,
                volume=volume,
                speed=speed,
                audio_path=save_path)
        except Exception:
            raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                      "tts postprocess failed.")

        return lang, target_sample_rate, wav_base64
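A minimal end-to-end usage sketch of the engine above, assuming the process is launched so that --conf resolves to the YAML file shown earlier and the pretrained static models can be downloaded on first use:

engine = TTSEngine()
lang, fs, wav_base64 = engine.run(
    sentence="您好，欢迎使用语音合成服务。",  # text to synthesize
    spk_id=0,
    speed=1.0,
    volume=1.0,
    sample_rate=0,            # 0 keeps the model's native sample rate
    save_path="output.wav")   # optionally persist the result
print(lang, fs, len(wav_base64))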
speechserving/speechserving/utils/audio_process.py (modified)
...
@@ -17,11 +17,11 @@ import numpy as np
 def wav2pcm(wavfile, pcmfile, data_type=np.int16):
-    f = open(wavfile, "rb")
-    f.seek(0)
-    f.read(44)
-    data = np.fromfile(f, dtype=data_type)
-    data.tofile(pcmfile)
+    with open(wavfile, "rb") as f:
+        f.seek(0)
+        f.read(44)
+        data = np.fromfile(f, dtype=data_type)
+        data.tofile(pcmfile)
 def pcm2wav(pcm_file, wav_file, channels=1, bits=16, sample_rate=16000):
...
@@ -52,7 +52,7 @@ def change_speed(sample_raw, speed_rate, sample_rate):
     :raises ValueError: If speed_rate <= 0.0.
     """
     if speed_rate == 1.0:
-        return
+        return sample_raw
     if speed_rate <= 0:
         raise ValueError("speed_rate should be greater than zero.")
...
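The second hunk is the behavioural fix here: change_speed previously returned None on its no-op path (speed_rate == 1.0), which would break any caller that uses the return value, as postprocess in tts_engine.py does. A small sketch of the behaviour the fix guarantees, assuming the speed_rate check is the first thing the function does:

import numpy as np

from utils.audio_process import change_speed

samples = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
out = change_speed(samples, 1.0, 16000)      # speed_rate == 1.0 is a no-op
assert out is samples                        # before this commit, out was None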
speechserving/speechserving/utils/paddle_predictor.py (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional

from paddle.inference import Config
from paddle.inference import create_predictor


def init_predictor(model_dir: Optional[os.PathLike]=None,
                   model_file: Optional[os.PathLike]=None,
                   params_file: Optional[os.PathLike]=None,
                   predictor_conf: dict=None):
    """Create a predictor with Paddle inference.

    Args:
        model_dir (Optional[os.PathLike], optional): The directory of the saved static model. Defaults to None.
        model_file (Optional[os.PathLike], optional): *.pdmodel file path. Defaults to None.
        params_file (Optional[os.PathLike], optional): *.pdiparams file path. Defaults to None.
        predictor_conf (dict, optional): The configuration parameters of the predictor. Defaults to None.

    Returns:
        paddle.inference.Predictor: the created predictor
    """
    if model_dir is not None:
        config = Config(model_dir)  # fixed: was `args.model_dir`, but no `args` exists in this scope
    else:
        config = Config(model_file, params_file)
    config.enable_memory_optim()

    if "use_gpu" in predictor_conf and predictor_conf["use_gpu"] == "true":
        config.enable_use_gpu(1000, 0)
    if "enable_mkldnn" in predictor_conf and predictor_conf[
            "enable_mkldnn"] == "true":
        config.enable_mkldnn()
    if "switch_ir_optim" in predictor_conf and predictor_conf[
            "switch_ir_optim"] == "true":
        config.switch_ir_optim()

    predictor = create_predictor(config)
    return predictor


def run_model(predictor, input: list):
    """Run the predictor.

    Args:
        predictor: paddle inference predictor
        input (list): the inputs of the predictor, ordered by input name

    Returns:
        list: one output array per output tensor
    """
    input_names = predictor.get_input_names()
    for i, name in enumerate(input_names):
        input_handle = predictor.get_input_handle(name)
        input_handle.copy_from_cpu(input[i])

    # do the inference
    predictor.run()

    results = []
    # get out data from output tensor
    output_names = predictor.get_output_names()
    for i, name in enumerate(output_names):
        output_handle = predictor.get_output_handle(name)
        output_data = output_handle.copy_to_cpu()
        results.append(output_data)

    return results
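Finally, a minimal sketch of driving these two helpers directly; the file paths are illustrative and assume a FastSpeech2 static model has already been downloaded and unpacked:

import numpy as np

from utils.paddle_predictor import init_predictor, run_model

predictor = init_predictor(
    model_file="fastspeech2_csmsc.pdmodel",
    params_file="fastspeech2_csmsc.pdiparams",
    predictor_conf={
        "use_gpu": "false",        # flags are the strings 'true'/'false', as in tts_pd.yaml
        "enable_mkldnn": "false",
        "switch_ir_optim": "true",
    })

phone_ids = np.array([1, 2, 3, 4], dtype=np.int64)  # illustrative phone ids
mel = run_model(predictor, [phone_ids])[0]
print(mel.shape)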