Commit b9ade180
Authored Aug 03, 2022 by 小湉湉
Parent: 070a08f2

add onnxruntime infer for cli

Showing 16 changed files with 450 additions and 281 deletions (+450 −281)
examples/aishell3/tts3/run.sh                      +1    -1
examples/ljspeech/tts3/run.sh                      +1    -1
examples/vctk/tts3/run.sh                          +1    -1
paddlespeech/cli/tts/infer.py                      +254  -96
paddlespeech/resource/pretrained_models.py         +1    -1
paddlespeech/t2s/exps/inference.py                 +0    -5
paddlespeech/t2s/exps/inference_streaming.py       +8    -8
paddlespeech/t2s/exps/ort_predict.py               +6    -6
paddlespeech/t2s/exps/ort_predict_e2e.py           +46   -45
paddlespeech/t2s/exps/ort_predict_streaming.py     +24   -20
paddlespeech/t2s/exps/syn_utils.py                 +62   -48
paddlespeech/t2s/exps/synthesize_e2e.py            +12   -21
paddlespeech/t2s/exps/synthesize_streaming.py      +8    -9
paddlespeech/t2s/frontend/mix_frontend.py          +5    -3
paddlespeech/t2s/frontend/phonectic.py             +6    -3
paddlespeech/t2s/frontend/zh_frontend.py           +15   -13
examples/aishell3/tts3/run.sh

```diff
@@ -54,7 +54,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
 fi

-# inference with onnxruntime, use fastspeech2 + hifigan by default
+# inference with onnxruntime, use fastspeech2 + pwgan by default
 if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
     ./local/ort_predict.sh ${train_output_path}
 fi
```
examples/ljspeech/tts3/run.sh

```diff
@@ -55,7 +55,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
     # ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_ljspeech
 fi

-# inference with onnxruntime, use fastspeech2 + hifigan by default
+# inference with onnxruntime, use fastspeech2 + pwgan by default
 if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
     ./local/ort_predict.sh ${train_output_path}
 fi
```
examples/vctk/tts3/run.sh

```diff
@@ -54,7 +54,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
 fi

-# inference with onnxruntime, use fastspeech2 + hifigan by default
+# inference with onnxruntime, use fastspeech2 + pwgan by default
 if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
     ./local/ort_predict.sh ${train_output_path}
 fi
```
paddlespeech/cli/tts/infer.py

```diff
@@ -29,10 +29,21 @@ from yacs.config import CfgNode
 from ..executor import BaseExecutor
 from ..log import logger
 from ..utils import stats_wrapper
 from paddlespeech.resource import CommonTaskResource
+from paddlespeech.t2s.exps.syn_utils import get_am_inference
 from paddlespeech.t2s.exps.syn_utils import get_frontend
-from paddlespeech.t2s.modules.normalizer import ZScore
+from paddlespeech.t2s.exps.syn_utils import get_sess
+from paddlespeech.t2s.exps.syn_utils import get_voc_inference
+from paddlespeech.t2s.exps.syn_utils import run_frontend
+from paddlespeech.t2s.utils import str2bool

 __all__ = ['TTSExecutor']

+ONNX_SUPPORT_SET = {
+    'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
+    'fastspeech2_aishell3', 'fastspeech2_vctk', 'pwgan_csmsc',
+    'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk', 'mb_melgan_csmsc',
+    'hifigan_csmsc', 'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk'
+}
+

 class TTSExecutor(BaseExecutor):
```
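The new ONNX_SUPPORT_SET whitelist is what the executor later consults before falling back to onnxruntime. A minimal sketch of that membership check, using the set and the assertion message visible in this diff (the helper name is hypothetical):

```python
# Sketch only: mirrors the check performed before building ONNX sessions.
ONNX_SUPPORT_SET = {
    'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
    'fastspeech2_aishell3', 'fastspeech2_vctk', 'pwgan_csmsc',
    'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk', 'mb_melgan_csmsc',
    'hifigan_csmsc', 'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk'
}

def check_onnx_pair(am: str, voc: str) -> None:
    # Fail early if either model has no released ONNX export.
    assert (am in ONNX_SUPPORT_SET and voc in ONNX_SUPPORT_SET
            ), f'the am and voc you choose, they should be in {ONNX_SUPPORT_SET}'

check_onnx_pair('fastspeech2_csmsc', 'hifigan_csmsc')
```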
```diff
@@ -142,6 +153,8 @@ class TTSExecutor(BaseExecutor):
             default=paddle.get_device(),
             help='Choose device to execute model inference.')
+        self.parser.add_argument(
+            '--cpu_threads', type=int, default=2)
         self.parser.add_argument(
             '--output', type=str, default='output.wav', help='output file name')
         self.parser.add_argument(
```

```diff
@@ -154,6 +167,11 @@ class TTSExecutor(BaseExecutor):
             '--verbose',
             action='store_true',
             help='Increase logger verbosity of current task.')
+        self.parser.add_argument(
+            "--use_onnx",
+            type=str2bool,
+            default=False,
+            help="whether to usen onnxruntime inference.")

     def _init_from_path(
             self,
```
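The new --use_onnx flag is parsed with str2bool rather than a bare bool, so that "True"/"False" strings coming from the shell become real booleans. A hedged sketch of how such a converter typically behaves (the real str2bool lives in paddlespeech.t2s.utils and may differ in detail):

```python
import argparse

def str2bool(v: str) -> bool:
    # Assumed behaviour: map common truthy/falsy strings to bool.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
parser.add_argument('--use_onnx', type=str2bool, default=False)
parser.add_argument('--cpu_threads', type=int, default=2)
print(parser.parse_args(['--use_onnx', 'True', '--cpu_threads', '4']))
```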
```diff
@@ -164,7 +182,7 @@ class TTSExecutor(BaseExecutor):
                         phones_dict: Optional[os.PathLike]=None,
                         tones_dict: Optional[os.PathLike]=None,
                         speaker_dict: Optional[os.PathLike]=None,
-                        voc: str='pwgan_csmsc',
+                        voc: str='hifigan_csmsc',
                         voc_config: Optional[os.PathLike]=None,
                         voc_ckpt: Optional[os.PathLike]=None,
                         voc_stat: Optional[os.PathLike]=None,
```
```diff
@@ -288,58 +306,111 @@ class TTSExecutor(BaseExecutor):
             lang=lang,
             phones_dict=self.phones_dict,
             tones_dict=self.tones_dict)

         # acoustic model
-        odim = self.am_config.n_mels
-        # model: {model_name}_{dataset}
-        am_name = am[:am.rindex('_')]
-        am_class = self.task_resource.get_model_class(am_name)
-        am_inference_class = self.task_resource.get_model_class(am_name +
-                                                                '_inference')
-        if am_name == 'fastspeech2':
-            am = am_class(
-                idim=vocab_size, odim=odim, spk_num=spk_num,
-                **self.am_config["model"])
-        elif am_name == 'speedyspeech':
-            am = am_class(
-                vocab_size=vocab_size,
-                tone_size=tone_size,
-                **self.am_config["model"])
-        elif am_name == 'tacotron2':
-            am = am_class(idim=vocab_size, odim=odim, **self.am_config["model"])
-
-        am.set_state_dict(paddle.load(self.am_ckpt)["main_params"])
-        am.eval()
-        am_mu, am_std = np.load(self.am_stat)
-        am_mu = paddle.to_tensor(am_mu)
-        am_std = paddle.to_tensor(am_std)
-        am_normalizer = ZScore(am_mu, am_std)
-        self.am_inference = am_inference_class(am_normalizer, am)
-        self.am_inference.eval()
+        self.am_inference = get_am_inference(
+            am=am,
+            am_config=self.am_config,
+            am_ckpt=self.am_ckpt,
+            am_stat=self.am_stat,
+            phones_dict=self.phones_dict,
+            tones_dict=self.tones_dict,
+            speaker_dict=self.speaker_dict)

         # vocoder
-        # model: {model_name}_{dataset}
-        voc_name = voc[:voc.rindex('_')]
-        voc_class = self.task_resource.get_model_class(voc_name)
-        voc_inference_class = self.task_resource.get_model_class(voc_name +
-                                                                 '_inference')
-        if voc_name != 'wavernn':
-            voc = voc_class(**self.voc_config["generator_params"])
-            voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"])
-            voc.remove_weight_norm()
-            voc.eval()
-        else:
-            voc = voc_class(**self.voc_config["model"])
-            voc.set_state_dict(paddle.load(self.voc_ckpt)["main_params"])
-            voc.eval()
-        voc_mu, voc_std = np.load(self.voc_stat)
-        voc_mu = paddle.to_tensor(voc_mu)
-        voc_std = paddle.to_tensor(voc_std)
-        voc_normalizer = ZScore(voc_mu, voc_std)
-        self.voc_inference = voc_inference_class(voc_normalizer, voc)
-        self.voc_inference.eval()
+        self.voc_inference = get_voc_inference(
+            voc=voc,
+            voc_config=self.voc_config,
+            voc_ckpt=self.voc_ckpt,
+            voc_stat=self.voc_stat)
+
+    def _init_from_path_onnx(self,
+                             am: str='fastspeech2_csmsc',
+                             am_ckpt: Optional[os.PathLike]=None,
+                             phones_dict: Optional[os.PathLike]=None,
+                             tones_dict: Optional[os.PathLike]=None,
+                             speaker_dict: Optional[os.PathLike]=None,
+                             voc: str='hifigan_csmsc',
+                             voc_ckpt: Optional[os.PathLike]=None,
+                             lang: str='zh',
+                             device: str='cpu',
+                             cpu_threads: int=2,
+                             fs: int=24000):
+        if hasattr(self, 'am_sess') and hasattr(self, 'voc_sess'):
+            logger.debug('Models had been initialized.')
+            return
+        # am
+        if am_ckpt is None or phones_dict is None:
+            use_pretrained_am = True
+        else:
+            use_pretrained_am = False
+
+        am_tag = am + '_onnx' + '-' + lang
+        self.task_resource.set_task_model(
+            model_tag=am_tag,
+            model_type=0,  # am
+            skip_download=not use_pretrained_am,
+            version=None,  # default version
+        )
+        if use_pretrained_am:
+            self.am_res_path = self.task_resource.res_dir
+            self.am_ckpt = os.path.join(self.am_res_path,
+                                        self.task_resource.res_dict['ckpt'][0])
+            # must have phones_dict in acoustic
+            self.phones_dict = os.path.join(
+                self.am_res_path, self.task_resource.res_dict['phones_dict'])
+            self.am_fs = self.task_resource.res_dict['sample_rate']
+            logger.debug(self.am_res_path)
+            logger.debug(self.am_ckpt)
+        else:
+            self.am_ckpt = os.path.abspath(am_ckpt[0])
+            self.phones_dict = os.path.abspath(phones_dict)
+            self.am_res_path = os.path.dirname(os.path.abspath(am_ckpt))
+            self.am_fs = fs
+
+        # for speedyspeech
+        self.tones_dict = None
+        if 'tones_dict' in self.task_resource.res_dict:
+            self.tones_dict = os.path.join(
+                self.am_res_path, self.task_resource.res_dict['tones_dict'])
+            if tones_dict:
+                self.tones_dict = tones_dict
+
+        # voc
+        if voc_ckpt is None:
+            use_pretrained_voc = True
+        else:
+            use_pretrained_voc = False
+        voc_lang = lang
+        # we must use ljspeech's voc for mix am now!
+        if lang == 'mix':
+            voc_lang = 'en'
+        voc_tag = voc + '_onnx' + '-' + voc_lang
+        self.task_resource.set_task_model(
+            model_tag=voc_tag,
+            model_type=1,  # vocoder
+            skip_download=not use_pretrained_voc,
+            version=None,  # default version
+        )
+        if use_pretrained_voc:
+            self.voc_res_path = self.task_resource.voc_res_dir
+            self.voc_ckpt = os.path.join(
+                self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])
+            logger.debug(self.voc_res_path)
+            logger.debug(self.voc_ckpt)
+        else:
+            self.voc_ckpt = os.path.abspath(voc_ckpt)
+            self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt))
+
+        # frontend
+        self.frontend = get_frontend(
+            lang=lang, phones_dict=self.phones_dict, tones_dict=self.tones_dict)
+        self.am_sess = get_sess(
+            model_path=self.am_ckpt, device=device, cpu_threads=cpu_threads)
+        # vocoder
+        self.voc_sess = get_sess(
+            model_path=self.voc_ckpt, device=device, cpu_threads=cpu_threads)

     def preprocess(self, input: Any, *args, **kwargs):
         """
```
```diff
@@ -362,40 +433,28 @@ class TTSExecutor(BaseExecutor):
         """
         am_name = am[:am.rindex('_')]
         am_dataset = am[am.rindex('_') + 1:]
-        get_tone_ids = False
         merge_sentences = False
-        frontend_st = time.time()
+        get_tone_ids = False
         if am_name == 'speedyspeech':
             get_tone_ids = True
-        if lang == 'zh':
-            input_ids = self.frontend.get_input_ids(
-                text,
-                merge_sentences=merge_sentences,
-                get_tone_ids=get_tone_ids)
-            phone_ids = input_ids["phone_ids"]
-            if get_tone_ids:
-                tone_ids = input_ids["tone_ids"]
-        elif lang == 'en':
-            input_ids = self.frontend.get_input_ids(
-                text, merge_sentences=merge_sentences)
-            phone_ids = input_ids["phone_ids"]
-        elif lang == 'mix':
-            input_ids = self.frontend.get_input_ids(
-                text, merge_sentences=merge_sentences)
-            phone_ids = input_ids["phone_ids"]
-        else:
-            logger.error("lang should in {'zh', 'en', 'mix'}!")
+        frontend_st = time.time()
+        frontend_dict = run_frontend(
+            frontend=self.frontend,
+            text=text,
+            merge_sentences=merge_sentences,
+            get_tone_ids=get_tone_ids,
+            lang=lang)
         self.frontend_time = time.time() - frontend_st
         self.am_time = 0
         self.voc_time = 0
         flags = 0
+        phone_ids = frontend_dict['phone_ids']
         for i in range(len(phone_ids)):
             am_st = time.time()
             part_phone_ids = phone_ids[i]
             # am
             if am_name == 'speedyspeech':
-                part_tone_ids = tone_ids[i]
+                part_tone_ids = frontend_dict['tone_ids'][i]
                 mel = self.am_inference(part_phone_ids, part_tone_ids)
             # fastspeech2
             else:
```
```diff
@@ -417,6 +476,62 @@ class TTSExecutor(BaseExecutor):
                 self.voc_time += (time.time() - voc_st)
         self._outputs['wav'] = wav_all

+    def infer_onnx(self,
+                   text: str,
+                   lang: str='zh',
+                   am: str='fastspeech2_csmsc',
+                   spk_id: int=0):
+        am_name = am[:am.rindex('_')]
+        am_dataset = am[am.rindex('_') + 1:]
+        merge_sentences = False
+        get_tone_ids = False
+        if am_name == 'speedyspeech':
+            get_tone_ids = True
+        am_input_feed = {}
+        frontend_st = time.time()
+        frontend_dict = run_frontend(
+            frontend=self.frontend,
+            text=text,
+            merge_sentences=merge_sentences,
+            get_tone_ids=get_tone_ids,
+            lang=lang,
+            to_tensor=False)
+        self.frontend_time = time.time() - frontend_st
+        phone_ids = frontend_dict['phone_ids']
+        self.am_time = 0
+        self.voc_time = 0
+        flags = 0
+        for i in range(len(phone_ids)):
+            am_st = time.time()
+            part_phone_ids = phone_ids[i]
+            if am_name == 'fastspeech2':
+                am_input_feed.update({'text': part_phone_ids})
+                if am_dataset in {"aishell3", "vctk"}:
+                    # NOTE: 'spk_id' should be List[int] rather than int here!!
+                    am_input_feed.update({'spk_id': [spk_id]})
+            elif am_name == 'speedyspeech':
+                part_tone_ids = frontend_dict['tone_ids'][i]
+                am_input_feed.update({
+                    'phones': part_phone_ids,
+                    'tones': part_tone_ids
+                })
+            mel = self.am_sess.run(output_names=None, input_feed=am_input_feed)
+            mel = mel[0]
+            self.am_time += (time.time() - am_st)
+            # voc
+            voc_st = time.time()
+            wav = self.voc_sess.run(
+                output_names=None, input_feed={'logmel': mel})
+            wav = wav[0]
+            if flags == 0:
+                wav_all = wav
+                flags = 1
+            else:
+                wav_all = np.concatenate([wav_all, wav])
+            self.voc_time += (time.time() - voc_st)
+        self._outputs['wav'] = wav_all
+
     def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]:
         """
             Output postprocess and return results.
```
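The new infer_onnx path feeds raw NumPy arrays into two onnxruntime sessions: the acoustic model takes a 'text' (or 'phones'/'tones') feed and the vocoder takes the predicted 'logmel'. A minimal standalone sketch of the same two-session pipeline, assuming released *_onnx exports are already on disk; the file names are placeholders and the random int64 phone ids stand in for real frontend output:

```python
import numpy as np
import onnxruntime as ort
import soundfile as sf

# Placeholder paths: point these at the downloaded ONNX model files.
am_sess = ort.InferenceSession('fastspeech2_csmsc.onnx',
                               providers=['CPUExecutionProvider'])
voc_sess = ort.InferenceSession('hifigan_csmsc.onnx',
                                providers=['CPUExecutionProvider'])

# Phone ids would normally come from run_frontend(..., to_tensor=False);
# a short random sequence is used here only to show the feed/fetch pattern.
phone_ids = np.random.randint(1, 92, size=(50, )).astype(np.int64)

mel = am_sess.run(output_names=None, input_feed={'text': phone_ids})[0]
wav = voc_sess.run(output_names=None, input_feed={'logmel': mel})[0]
sf.write('output.wav', wav, samplerate=24000)  # csmsc models run at 24 kHz
```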
```diff
@@ -430,6 +545,20 @@ class TTSExecutor(BaseExecutor):
             output, self._outputs['wav'].numpy(), samplerate=self.am_config.fs)
         return output

+    def postprocess_onnx(self,
+                         output: str='output.wav') -> Union[str, os.PathLike]:
+        """
+            Output postprocess and return results.
+            This method get model output from self._outputs and convert it into human-readable results.
+
+            Returns:
+                Union[str, os.PathLike]: Human-readable results such as texts and audio files.
+        """
+        output = os.path.abspath(os.path.expanduser(output))
+        sf.write(output, self._outputs['wav'], samplerate=self.am_fs)
+        return output
+
+    # the CLI entry point is here
     def execute(self, argv: List[str]) -> bool:
         """
             Command line entry.
```
```diff
@@ -451,6 +580,8 @@ class TTSExecutor(BaseExecutor):
         lang = args.lang
         device = args.device
         spk_id = args.spk_id
+        use_onnx = args.use_onnx
+        cpu_threads = args.cpu_threads

         if not args.verbose:
             self.disable_task_loggers()
```
```diff
@@ -487,7 +618,9 @@ class TTSExecutor(BaseExecutor):
                     # other
                     lang=lang,
                     device=device,
-                    output=output)
+                    output=output,
+                    use_onnx=use_onnx,
+                    cpu_threads=cpu_threads)
                 task_results[id_] = res
             except Exception as e:
                 has_exceptions = True
```
```diff
@@ -501,6 +634,7 @@ class TTSExecutor(BaseExecutor):
         else:
             return True

+    # the Python API entry point is here
     @stats_wrapper
     def __call__(self,
                  text: str,
```
```diff
@@ -512,33 +646,57 @@ class TTSExecutor(BaseExecutor):
                 phones_dict: Optional[os.PathLike]=None,
                 tones_dict: Optional[os.PathLike]=None,
                 speaker_dict: Optional[os.PathLike]=None,
-                voc: str='pwgan_csmsc',
+                voc: str='hifigan_csmsc',
                 voc_config: Optional[os.PathLike]=None,
                 voc_ckpt: Optional[os.PathLike]=None,
                 voc_stat: Optional[os.PathLike]=None,
                 lang: str='zh',
                 device: str=paddle.get_device(),
-                output: str='output.wav'):
+                output: str='output.wav',
+                use_onnx: bool=False,
+                cpu_threads: int=2):
         """
         Python API to call an executor.
         """
-        paddle.set_device(device)
-        self._init_from_path(
-            am=am,
-            am_config=am_config,
-            am_ckpt=am_ckpt,
-            am_stat=am_stat,
-            phones_dict=phones_dict,
-            tones_dict=tones_dict,
-            speaker_dict=speaker_dict,
-            voc=voc,
-            voc_config=voc_config,
-            voc_ckpt=voc_ckpt,
-            voc_stat=voc_stat,
-            lang=lang)
-
-        self.infer(text=text, lang=lang, am=am, spk_id=spk_id)
-        res = self.postprocess(output=output)
-        return res
+        if not use_onnx:
+            paddle.set_device(device)
+            self._init_from_path(
+                am=am,
+                am_config=am_config,
+                am_ckpt=am_ckpt,
+                am_stat=am_stat,
+                phones_dict=phones_dict,
+                tones_dict=tones_dict,
+                speaker_dict=speaker_dict,
+                voc=voc,
+                voc_config=voc_config,
+                voc_ckpt=voc_ckpt,
+                voc_stat=voc_stat,
+                lang=lang)
+
+            self.infer(text=text, lang=lang, am=am, spk_id=spk_id)
+            res = self.postprocess(output=output)
+            return res
+        else:
+            # use onnx
+            # we use `cpu` for onnxruntime by default
+            # please see description in https://github.com/PaddlePaddle/PaddleSpeech/pull/2220
+            self.task_resource = CommonTaskResource(
+                task='tts', model_format='onnx')
+            assert (
+                am in ONNX_SUPPORT_SET and voc in ONNX_SUPPORT_SET
+            ), f'the am and voc you choose, they should be in {ONNX_SUPPORT_SET}'
+            self._init_from_path_onnx(
+                am=am,
+                am_ckpt=am_ckpt,
+                phones_dict=phones_dict,
+                tones_dict=tones_dict,
+                speaker_dict=speaker_dict,
+                voc=voc,
+                voc_ckpt=voc_ckpt,
+                lang=lang,
+                device=device,
+                cpu_threads=cpu_threads)
+            self.infer_onnx(text=text, lang=lang, am=am, spk_id=spk_id)
+            res = self.postprocess_onnx(output=output)
+            return res
```
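With these changes the Python API mirrors the new CLI flag: passing use_onnx=True routes a call through _init_from_path_onnx, infer_onnx and postprocess_onnx instead of the dygraph path. A short usage sketch (model tags are the defaults visible in this diff; text and output paths are arbitrary):

```python
from paddlespeech.cli.tts.infer import TTSExecutor

tts = TTSExecutor()
# Default dygraph inference.
tts(text='今天天气十分不错。', output='default.wav')
# onnxruntime inference on CPU, added by this commit.
tts(text='今天天气十分不错。',
    am='fastspeech2_csmsc',
    voc='hifigan_csmsc',
    use_onnx=True,
    cpu_threads=2,
    output='onnx.wav')
```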
paddlespeech/resource/pretrained_models.py

```diff
@@ -1149,7 +1149,7 @@ tts_onnx_pretrained_models = {
     "fastspeech2_vctk_onnx-en": {
         '1.0': {
             'url':
-            'hhttps://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_onnx_1.1.0.zip',
+            'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_onnx_1.1.0.zip',
             'md5':
             'd9c3a9b02204a2070504dd99f5f959bf',
             'ckpt': ['fastspeech2_vctk.onnx'],
```
paddlespeech/t2s/exps/inference.py

```diff
@@ -86,11 +86,6 @@ def parse_args():
         "--inference_dir", type=str, help="dir to save inference models")
     parser.add_argument("--output_dir", type=str, help="output dir")
     # inference
-    parser.add_argument(
-        "--use_trt",
-        type=str2bool,
-        default=False,
-        help="Whether to use inference engin TensorRT.", )
     parser.add_argument(
         "--int8",
         type=str2bool,
```
paddlespeech/t2s/exps/inference_streaming.py

```diff
@@ -27,6 +27,7 @@ from paddlespeech.t2s.exps.syn_utils import get_predictor
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output
 from paddlespeech.t2s.exps.syn_utils import get_voc_output
+from paddlespeech.t2s.exps.syn_utils import run_frontend
 from paddlespeech.t2s.utils import str2bool
```

```diff
@@ -175,14 +176,13 @@ def main():
     for utt_id, sentence in sentences:
         with timer() as t:
             # frontend
-            if args.lang == 'zh':
-                input_ids = frontend.get_input_ids(
-                    sentence,
-                    merge_sentences=merge_sentences,
-                    get_tone_ids=get_tone_ids)
-                phone_ids = input_ids["phone_ids"]
-            else:
-                print("lang should be 'zh' here!")
+            frontend_dict = run_frontend(
+                frontend=frontend,
+                text=sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids,
+                lang=args.lang)
+            phone_ids = frontend_dict['phone_ids']
             phones = phone_ids[0].numpy()
             # acoustic model
             orig_hs = get_am_sublayer_output(
```
paddlespeech/t2s/exps/ort_predict.py

```diff
@@ -41,17 +41,17 @@ def ort_predict(args):
     # am
     am_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.am + ".onnx",
+        model_path=str(Path(args.inference_dir) / (args.am + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)

     # vocoder
     voc_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.voc + ".onnx",
+        model_path=str(Path(args.inference_dir) / (args.voc + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)

     # am warmup
     for T in [27, 38, 54]:
```
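The call sites now build the full .onnx path themselves instead of handing get_sess a directory plus file name. A tiny sketch of the same pathlib pattern (directory and model names are placeholders):

```python
from pathlib import Path

inference_dir = 'exp/default/inference_onnx'  # placeholder directory
for name in ('fastspeech2_csmsc', 'hifigan_csmsc'):
    model_path = str(Path(inference_dir) / (name + '.onnx'))
    print(model_path)  # e.g. exp/default/inference_onnx/fastspeech2_csmsc.onnx
```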
paddlespeech/t2s/exps/ort_predict_e2e.py

```diff
@@ -22,6 +22,7 @@ from timer import timer
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_sess
+from paddlespeech.t2s.exps.syn_utils import run_frontend
 from paddlespeech.t2s.utils import str2bool
```

```diff
@@ -42,17 +43,17 @@ def ort_predict(args):
     fs = 24000 if am_dataset != 'ljspeech' else 22050
     am_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.am + ".onnx",
+        model_path=str(Path(args.inference_dir) / (args.am + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)

     # vocoder
     voc_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.voc + ".onnx",
+        model_path=str(Path(args.inference_dir) / (args.voc + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)

     merge_sentences = True
```

```diff
@@ -78,7 +79,6 @@ def ort_predict(args):
         am_input_feed.update({'text': phone_ids})
         if am_dataset in {"aishell3", "vctk"}:
             am_input_feed.update({'spk_id': spk_id})
     elif am_name == 'speedyspeech':
         phone_ids = np.random.randint(1, 92, size=(T, ))
         tone_ids = np.random.randint(1, 5, size=(T, ))
```

```diff
@@ -93,50 +93,51 @@ def ort_predict(args):
     N = 0
     T = 0
-    merge_sentences = True
+    merge_sentences = False
     get_tone_ids = False
-    am_input_feed = {}
     if am_name == 'speedyspeech':
         get_tone_ids = True
+    am_input_feed = {}
     for utt_id, sentence in sentences:
         with timer() as t:
-            if args.lang == 'zh':
-                input_ids = frontend.get_input_ids(
-                    sentence,
-                    merge_sentences=merge_sentences,
-                    get_tone_ids=get_tone_ids)
-                phone_ids = input_ids["phone_ids"]
-                if get_tone_ids:
-                    tone_ids = input_ids["tone_ids"]
-            elif args.lang == 'en':
-                input_ids = frontend.get_input_ids(
-                    sentence, merge_sentences=merge_sentences)
-                phone_ids = input_ids["phone_ids"]
-            else:
-                print("lang should in {'zh', 'en'}!")
-            # merge_sentences=True here, so we only use the first item of phone_ids
-            phone_ids = phone_ids[0].numpy()
-            if am_name == 'fastspeech2':
-                am_input_feed.update({'text': phone_ids})
-                if am_dataset in {"aishell3", "vctk"}:
-                    am_input_feed.update({'spk_id': spk_id})
-            elif am_name == 'speedyspeech':
-                tone_ids = tone_ids[0].numpy()
-                am_input_feed.update({'phones': phone_ids, 'tones': tone_ids})
-            mel = am_sess.run(output_names=None, input_feed=am_input_feed)
-            mel = mel[0]
-            wav = voc_sess.run(
-                output_names=None, input_feed={'logmel': mel})
-            N += len(wav[0])
-            T += t.elapse
-            speed = len(wav[0]) / t.elapse
-            rtf = fs / speed
-        sf.write(
-            str(output_dir / (utt_id + ".wav")),
-            np.array(wav)[0],
-            samplerate=fs)
-        print(
-            f"{utt_id}, mel: {mel.shape}, wave: {len(wav[0])}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}.")
+            frontend_dict = run_frontend(
+                frontend=frontend,
+                text=sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids,
+                lang=args.lang)
+            phone_ids = frontend_dict['phone_ids']
+            flags = 0
+            for i in range(len(phone_ids)):
+                part_phone_ids = phone_ids[i].numpy()
+                if am_name == 'fastspeech2':
+                    am_input_feed.update({'text': part_phone_ids})
+                    if am_dataset in {"aishell3", "vctk"}:
+                        am_input_feed.update({'spk_id': spk_id})
+                elif am_name == 'speedyspeech':
+                    part_tone_ids = frontend_dict['tone_ids'][i].numpy()
+                    am_input_feed.update({
+                        'phones': part_phone_ids,
+                        'tones': part_tone_ids
+                    })
+                mel = am_sess.run(output_names=None, input_feed=am_input_feed)
+                mel = mel[0]
+                wav = voc_sess.run(
+                    output_names=None, input_feed={'logmel': mel})
+                wav = wav[0]
+                if flags == 0:
+                    wav_all = wav
+                    flags = 1
+                else:
+                    wav_all = np.concatenate([wav_all, wav])
+            wav = wav_all
+            N += len(wav)
+            T += t.elapse
+            speed = len(wav) / t.elapse
+            rtf = fs / speed
+        sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=fs)
+        print(
+            f"{utt_id}, mel: {mel.shape}, wave: {len(wav)}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}.")
     print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")
```
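The timing lines in this script derive real-time factor from the synthesized sample count: speed is samples generated per elapsed second and RTF is the output sample rate divided by that speed. A worked example with assumed numbers:

```python
fs = 24000        # output sample rate (csmsc models in this diff)
elapsed = 0.5     # seconds spent in AM + vocoder, assumed
wav_len = 96000   # samples produced, i.e. 4 s of audio, assumed

speed = wav_len / elapsed  # 192000 samples generated per second
rtf = fs / speed           # 0.125: synthesis runs 8x faster than real time
print(f"Hz: {speed}, RTF: {rtf}")
```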
paddlespeech/t2s/exps/ort_predict_streaming.py

```diff
@@ -24,6 +24,7 @@ from paddlespeech.t2s.exps.syn_utils import get_chunks
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_sess
+from paddlespeech.t2s.exps.syn_utils import run_frontend
 from paddlespeech.t2s.utils import str2bool
```

```diff
@@ -45,29 +46,33 @@ def ort_predict(args):
     # streaming acoustic model
     am_encoder_infer_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.am + "_am_encoder_infer" + ".onnx",
+        model_path=str(
+            Path(args.inference_dir) /
+            (args.am + '_am_encoder_infer' + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)

     am_decoder_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.am + "_am_decoder" + ".onnx",
+        model_path=str(
+            Path(args.inference_dir) / (args.am + '_am_decoder' + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)

     am_postnet_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.am + "_am_postnet" + ".onnx",
+        model_path=str(
+            Path(args.inference_dir) / (args.am + '_am_postnet' + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)
     am_mu, am_std = np.load(args.am_stat)

     # vocoder
     voc_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.voc + ".onnx",
+        model_path=str(Path(args.inference_dir) / (args.voc + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)

     # frontend warmup
     # Loading model cost 0.5+ seconds
```

```diff
@@ -102,14 +107,13 @@ def ort_predict(args):
     for utt_id, sentence in sentences:
         with timer() as t:
-            if args.lang == 'zh':
-                input_ids = frontend.get_input_ids(
-                    sentence,
-                    merge_sentences=merge_sentences,
-                    get_tone_ids=get_tone_ids)
-                phone_ids = input_ids["phone_ids"]
-            else:
-                print("lang should in be 'zh' here!")
+            frontend_dict = run_frontend(
+                frontend=frontend,
+                text=sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids,
+                lang=args.lang)
+            phone_ids = frontend_dict['phone_ids']
             # merge_sentences=True here, so we only use the first item of phone_ids
             phone_ids = phone_ids[0].numpy()
             orig_hs = am_encoder_infer_sess.run(
```
paddlespeech/t2s/exps/syn_utils.py

```diff
@@ -33,6 +33,8 @@ from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
 from paddlespeech.t2s.frontend.zh_frontend import Frontend
 from paddlespeech.t2s.modules.normalizer import ZScore
 from paddlespeech.utils.dynamic_import import dynamic_import

+# remove [W:onnxruntime: xxx] from ort
+ort.set_default_logger_severity(3)
+
 model_alias = {
     # acoustic model
```
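onnxruntime emits its warnings through a process-wide logger, so raising the default severity once at import time silences the [W:onnxruntime: ...] noise for every session created afterwards. A minimal sketch of the same call:

```python
import onnxruntime as ort

# Severity levels: 0=VERBOSE, 1=INFO, 2=WARNING, 3=ERROR, 4=FATAL.
# 3 keeps errors and fatals but hides warnings and below.
ort.set_default_logger_severity(3)
```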
```diff
@@ -161,13 +163,42 @@ def get_frontend(lang: str='zh',
     elif lang == 'mix':
         frontend = MixFrontend(
             phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
     else:
         print("wrong lang!")
-    print("frontend done!")
     return frontend


+def run_frontend(frontend: object,
+                 text: str,
+                 merge_sentences: bool=False,
+                 get_tone_ids: bool=False,
+                 lang: str='zh',
+                 to_tensor: bool=True):
+    outs = dict()
+    if lang == 'zh':
+        input_ids = frontend.get_input_ids(
+            text,
+            merge_sentences=merge_sentences,
+            get_tone_ids=get_tone_ids,
+            to_tensor=to_tensor)
+        phone_ids = input_ids["phone_ids"]
+        if get_tone_ids:
+            tone_ids = input_ids["tone_ids"]
+            outs.update({'tone_ids': tone_ids})
+    elif lang == 'en':
+        input_ids = frontend.get_input_ids(
+            text, merge_sentences=merge_sentences, to_tensor=to_tensor)
+        phone_ids = input_ids["phone_ids"]
+    elif lang == 'mix':
+        input_ids = frontend.get_input_ids(
+            text, merge_sentences=merge_sentences, to_tensor=to_tensor)
+        phone_ids = input_ids["phone_ids"]
+    else:
+        print("lang should in {'zh', 'en', 'mix'}!")
+    outs.update({'phone_ids': phone_ids})
+    return outs
+
+
 # dygraph
 def get_am_inference(am: str='fastspeech2_csmsc',
                      am_config: CfgNode=None,
```
```diff
@@ -180,30 +211,22 @@ def get_am_inference(am: str='fastspeech2_csmsc',
     with open(phones_dict, "r") as f:
         phn_id = [line.strip().split() for line in f.readlines()]
     vocab_size = len(phn_id)
-    print("vocab_size:", vocab_size)

     tone_size = None
     if tones_dict is not None:
         with open(tones_dict, "r") as f:
             tone_id = [line.strip().split() for line in f.readlines()]
         tone_size = len(tone_id)
-        print("tone_size:", tone_size)

     spk_num = None
     if speaker_dict is not None:
         with open(speaker_dict, 'rt') as f:
             spk_id = [line.strip().split() for line in f.readlines()]
         spk_num = len(spk_id)
-        print("spk_num:", spk_num)

     odim = am_config.n_mels
     # model: {model_name}_{dataset}
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
     am_class = dynamic_import(am_name, model_alias)
     am_inference_class = dynamic_import(am_name + '_inference', model_alias)
     if am_name == 'fastspeech2':
         am = am_class(
             idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
```
```diff
@@ -228,7 +251,6 @@ def get_am_inference(am: str='fastspeech2_csmsc',
     am_normalizer = ZScore(am_mu, am_std)
     am_inference = am_inference_class(am_normalizer, am)
     am_inference.eval()
-    print("acoustic model done!")
     if return_am:
         return am_inference, am
     else:
```
```diff
@@ -260,7 +282,6 @@ def get_voc_inference(
     voc_normalizer = ZScore(voc_mu, voc_std)
     voc_inference = voc_inference_class(voc_normalizer, voc)
     voc_inference.eval()
-    print("voc done!")
     return voc_inference
```
```diff
@@ -342,9 +363,9 @@ def get_predictor(model_dir: Optional[os.PathLike]=None,
 def get_am_output(
         input: str,
-        am_predictor,
-        am,
-        frontend,
+        am_predictor: paddle.nn.Layer,
+        am: str,
+        frontend: object,
         lang: str='zh',
         merge_sentences: bool=True,
         speaker_dict: Optional[os.PathLike]=None,
```
```diff
@@ -352,30 +373,23 @@ def get_am_output(
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
     am_input_names = am_predictor.get_input_names()
-    get_tone_ids = False
     get_spk_id = False
+    get_tone_ids = False
     if am_name == 'speedyspeech':
         get_tone_ids = True
     if am_dataset in {"aishell3", "vctk"} and speaker_dict:
         get_spk_id = True
         spk_id = np.array([spk_id])
-    if lang == 'zh':
-        input_ids = frontend.get_input_ids(
-            input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
-        phone_ids = input_ids["phone_ids"]
-    elif lang == 'en':
-        input_ids = frontend.get_input_ids(
-            input, merge_sentences=merge_sentences)
-        phone_ids = input_ids["phone_ids"]
-    elif lang == 'mix':
-        input_ids = frontend.get_input_ids(
-            input, merge_sentences=merge_sentences)
-        phone_ids = input_ids["phone_ids"]
-    else:
-        print("lang should in {'zh', 'en', 'mix'}!")
+    frontend_dict = run_frontend(
+        frontend=frontend,
+        text=input,
+        merge_sentences=merge_sentences,
+        get_tone_ids=get_tone_ids,
+        lang=lang)

     if get_tone_ids:
-        tone_ids = input_ids["tone_ids"]
+        tone_ids = frontend_dict['tone_ids']
         tones = tone_ids[0].numpy()
         tones_handle = am_predictor.get_input_handle(am_input_names[1])
         tones_handle.reshape(tones.shape)
```
```diff
@@ -384,6 +398,7 @@ def get_am_output(
         spk_id_handle = am_predictor.get_input_handle(am_input_names[1])
         spk_id_handle.reshape(spk_id.shape)
         spk_id_handle.copy_from_cpu(spk_id)
+    phone_ids = frontend_dict['phone_ids']
     phones = phone_ids[0].numpy()
     phones_handle = am_predictor.get_input_handle(am_input_names[0])
     phones_handle.reshape(phones.shape)
```
@@ -432,13 +447,13 @@ def get_streaming_am_output(input: str,
...
@@ -432,13 +447,13 @@ def get_streaming_am_output(input: str,
lang
:
str
=
'zh'
,
lang
:
str
=
'zh'
,
merge_sentences
:
bool
=
True
):
merge_sentences
:
bool
=
True
):
get_tone_ids
=
False
get_tone_ids
=
False
if
lang
==
'zh'
:
frontend_dict
=
run_frontend
(
input_ids
=
frontend
.
get_input_ids
(
frontend
=
frontend
,
input
,
merge_sentences
=
merge_sentences
,
get_tone_ids
=
get_tone_ids
)
text
=
input
,
phone_ids
=
input_ids
[
"phone_ids"
]
merge_sentences
=
merge_sentences
,
else
:
get_tone_ids
=
get_tone_ids
,
print
(
"lang should be 'zh' here!"
)
lang
=
lang
)
phone_ids
=
frontend_dict
[
'phone_ids'
]
phones
=
phone_ids
[
0
].
numpy
()
phones
=
phone_ids
[
0
].
numpy
()
am_encoder_infer_output
=
get_am_sublayer_output
(
am_encoder_infer_output
=
get_am_sublayer_output
(
am_encoder_infer_predictor
,
input
=
phones
)
am_encoder_infer_predictor
,
input
=
phones
)
...
@@ -455,26 +470,25 @@ def get_streaming_am_output(input: str,
...
@@ -455,26 +470,25 @@ def get_streaming_am_output(input: str,
# onnx
# onnx
def
get_sess
(
model_dir
:
Optional
[
os
.
PathLike
]
=
None
,
def
get_sess
(
model_path
:
Optional
[
os
.
PathLike
],
model_file
:
Optional
[
os
.
PathLike
]
=
None
,
device
:
str
=
'cpu'
,
device
:
str
=
'cpu'
,
cpu_threads
:
int
=
1
,
cpu_threads
:
int
=
1
,
use_trt
:
bool
=
False
):
use_trt
:
bool
=
False
):
model_dir
=
str
(
Path
(
model_dir
)
/
model_file
)
sess_options
=
ort
.
SessionOptions
()
sess_options
=
ort
.
SessionOptions
()
sess_options
.
graph_optimization_level
=
ort
.
GraphOptimizationLevel
.
ORT_ENABLE_ALL
sess_options
.
graph_optimization_level
=
ort
.
GraphOptimizationLevel
.
ORT_ENABLE_ALL
sess_options
.
execution_mode
=
ort
.
ExecutionMode
.
ORT_SEQUENTIAL
sess_options
.
execution_mode
=
ort
.
ExecutionMode
.
ORT_SEQUENTIAL
if
'gpu'
in
device
.
lower
():
if
device
==
"gpu"
:
device_id
=
int
(
device
.
split
(
':'
)[
1
])
if
len
(
device
.
split
(
':'
))
==
2
else
0
# fastspeech2/mb_melgan can't use trt now!
# fastspeech2/mb_melgan can't use trt now!
if
use_trt
:
if
use_trt
:
provider
s
=
[
'TensorrtExecutionProvider'
]
provider
_name
=
'TensorrtExecutionProvider'
else
:
else
:
providers
=
[
'CUDAExecutionProvider'
]
provider_name
=
'CUDAExecutionProvider'
elif
device
==
"cpu"
:
providers
=
[(
provider_name
,
{
'device_id'
:
device_id
})]
elif
device
.
lower
()
==
'cpu'
:
providers
=
[
'CPUExecutionProvider'
]
providers
=
[
'CPUExecutionProvider'
]
sess_options
.
intra_op_num_threads
=
cpu_threads
sess_options
.
intra_op_num_threads
=
cpu_threads
sess
=
ort
.
InferenceSession
(
sess
=
ort
.
InferenceSession
(
model_
dir
,
providers
=
providers
,
sess_options
=
sess_options
)
model_
path
,
providers
=
providers
,
sess_options
=
sess_options
)
return
sess
return
sess
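Because the feed-dict keys differ between acoustic models ('text' vs 'phones'/'tones', plus 'spk_id' for multi-speaker exports), it can help to ask the session itself what it expects before building the input_feed that get_sess's caller will use. A small hedged sketch using standard onnxruntime introspection (the model path is a placeholder):

```python
import onnxruntime as ort

sess = ort.InferenceSession('fastspeech2_aishell3.onnx',  # placeholder path
                            providers=['CPUExecutionProvider'])
# Inputs, e.g. 'text' and 'spk_id' for multi-speaker fastspeech2,
# or 'phones' and 'tones' for speedyspeech exports.
for node in sess.get_inputs():
    print('input :', node.name, node.shape, node.type)
# Outputs, e.g. the log-mel spectrogram consumed by the vocoder session.
for node in sess.get_outputs():
    print('output:', node.name, node.shape, node.type)
```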
paddlespeech/t2s/exps/synthesize_e2e.py

```diff
@@ -25,6 +25,7 @@ from paddlespeech.t2s.exps.syn_utils import get_am_inference
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_voc_inference
+from paddlespeech.t2s.exps.syn_utils import run_frontend
 from paddlespeech.t2s.exps.syn_utils import voc_to_static
```

```diff
@@ -49,6 +50,7 @@ def evaluate(args):
         lang=args.lang,
         phones_dict=args.phones_dict,
         tones_dict=args.tones_dict)
+    print("frontend done!")

     # acoustic model
     am_name = args.am[:args.am.rindex('_')]
```

```diff
@@ -62,13 +64,14 @@ def evaluate(args):
         phones_dict=args.phones_dict,
         tones_dict=args.tones_dict,
         speaker_dict=args.speaker_dict)
+    print("acoustic model done!")

     # vocoder
     voc_inference = get_voc_inference(
         voc=args.voc,
         voc_config=voc_config,
         voc_ckpt=args.voc_ckpt,
         voc_stat=args.voc_stat)
+    print("voc done!")

     # whether dygraph to static
     if args.inference_dir:
```

```diff
@@ -78,7 +81,6 @@ def evaluate(args):
             am=args.am,
             inference_dir=args.inference_dir,
             speaker_dict=args.speaker_dict)
-
         # vocoder
         voc_inference = voc_to_static(
             voc_inference=voc_inference,
```

```diff
@@ -101,24 +103,13 @@ def evaluate(args):
     T = 0
     for utt_id, sentence in sentences:
         with timer() as t:
-            if args.lang == 'zh':
-                input_ids = frontend.get_input_ids(
-                    sentence,
-                    merge_sentences=merge_sentences,
-                    get_tone_ids=get_tone_ids)
-                phone_ids = input_ids["phone_ids"]
-                if get_tone_ids:
-                    tone_ids = input_ids["tone_ids"]
-            elif args.lang == 'en':
-                input_ids = frontend.get_input_ids(
-                    sentence, merge_sentences=merge_sentences)
-                phone_ids = input_ids["phone_ids"]
-            elif args.lang == 'mix':
-                input_ids = frontend.get_input_ids(
-                    sentence, merge_sentences=merge_sentences)
-                phone_ids = input_ids["phone_ids"]
-            else:
-                print("lang should in {'zh', 'en', 'mix'}!")
+            frontend_dict = run_frontend(
+                frontend=frontend,
+                text=sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids,
+                lang=args.lang)
+            phone_ids = frontend_dict['phone_ids']
             with paddle.no_grad():
                 flags = 0
                 for i in range(len(phone_ids)):
```

```diff
@@ -132,7 +123,7 @@ def evaluate(args):
                     else:
                         mel = am_inference(part_phone_ids)
                 elif am_name == 'speedyspeech':
-                    part_tone_ids = tone_ids[i]
+                    part_tone_ids = frontend_dict['tone_ids'][i]
                     if am_dataset in {"aishell3", "vctk"}:
                         spk_id = paddle.to_tensor(args.spk_id)
                         mel = am_inference(part_phone_ids, part_tone_ids,
```
paddlespeech/t2s/exps/synthesize_streaming.py

```diff
@@ -30,6 +30,7 @@ from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_voc_inference
 from paddlespeech.t2s.exps.syn_utils import model_alias
+from paddlespeech.t2s.exps.syn_utils import run_frontend
 from paddlespeech.t2s.exps.syn_utils import voc_to_static
 from paddlespeech.t2s.utils import str2bool
 from paddlespeech.utils.dynamic_import import dynamic_import
```

```diff
@@ -138,15 +139,13 @@ def evaluate(args):
     for utt_id, sentence in sentences:
         with timer() as t:
-            if args.lang == 'zh':
-                input_ids = frontend.get_input_ids(
-                    sentence,
-                    merge_sentences=merge_sentences,
-                    get_tone_ids=get_tone_ids)
-                phone_ids = input_ids["phone_ids"]
-            else:
-                print("lang should be 'zh' here!")
+            frontend_dict = run_frontend(
+                frontend=frontend,
+                text=sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids,
+                lang=args.lang)
+            phone_ids = frontend_dict['phone_ids']
             # merge_sentences=True here, so we only use the first item of phone_ids
             phone_ids = phone_ids[0]
             with paddle.no_grad():
```
paddlespeech/t2s/frontend/mix_frontend.py

```diff
@@ -136,7 +136,8 @@ class MixFrontend():
                       sentence: str,
                       merge_sentences: bool=True,
                       get_tone_ids: bool=False,
-                      add_sp: bool=True) -> Dict[str, List[paddle.Tensor]]:
+                      add_sp: bool=True,
+                      to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:

         sentences = self._split(sentence)
         phones_list = []
```

```diff
@@ -152,11 +153,12 @@ class MixFrontend():
                     input_ids = self.zh_frontend.get_input_ids(
                         content,
                         merge_sentences=True,
-                        get_tone_ids=get_tone_ids)
+                        get_tone_ids=get_tone_ids,
+                        to_tensor=to_tensor)
                 elif lang == "en":
                     input_ids = self.en_frontend.get_input_ids(
-                        content, merge_sentences=True)
+                        content, merge_sentences=True, to_tensor=to_tensor)

                 phones_seg.append(input_ids["phone_ids"][0])
             if add_sp:
```
paddlespeech/t2s/frontend/phonectic.py

```diff
@@ -82,8 +82,10 @@ class English(Phonetics):
         phone_ids = [self.vocab_phones[item] for item in phonemes]
         return np.array(phone_ids, np.int64)

-    def get_input_ids(self, sentence: str,
-                      merge_sentences: bool=False) -> paddle.Tensor:
+    def get_input_ids(self,
+                      sentence: str,
+                      merge_sentences: bool=False,
+                      to_tensor: bool=True) -> paddle.Tensor:
         result = {}
         sentences = self.text_normalizer._split(sentence, lang="en")
         phones_list = []
```

```diff
@@ -112,7 +114,8 @@ class English(Phonetics):
             for part_phones_list in phones_list:
                 phone_ids = self._p2id(part_phones_list)
-                phone_ids = paddle.to_tensor(phone_ids)
+                if to_tensor:
+                    phone_ids = paddle.to_tensor(phone_ids)
                 temp_phone_ids.append(phone_ids)
             result["phone_ids"] = temp_phone_ids
         return result
```
paddlespeech/t2s/frontend/zh_frontend.py

```diff
@@ -303,15 +303,15 @@ class Frontend():
             print("----------------------------")
         return phonemes

     def get_input_ids(self,
                       sentence: str,
                       merge_sentences: bool=True,
                       get_tone_ids: bool=False,
                       robot: bool=False,
                       print_info: bool=False,
                       add_blank: bool=False,
-                      blank_token: str="<pad>") -> Dict[str, List[paddle.Tensor]]:
+                      blank_token: str="<pad>",
+                      to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
         phonemes = self.get_phonemes(
             sentence,
             merge_sentences=merge_sentences,
```

```diff
@@ -322,20 +322,22 @@ class Frontend():
         tones = []
         temp_phone_ids = []
         temp_tone_ids = []
         for part_phonemes in phonemes:
             phones, tones = self._get_phone_tone(
                 part_phonemes, get_tone_ids=get_tone_ids)
             if add_blank:
                 phones = insert_after_character(phones, blank_token)
             if tones:
                 tone_ids = self._t2id(tones)
-                tone_ids = paddle.to_tensor(tone_ids)
+                if to_tensor:
+                    tone_ids = paddle.to_tensor(tone_ids)
                 temp_tone_ids.append(tone_ids)
             if phones:
                 phone_ids = self._p2id(phones)
-                phone_ids = paddle.to_tensor(phone_ids)
+                # if use paddle.to_tensor() in onnxruntime, the first time will be too low
+                if to_tensor:
+                    phone_ids = paddle.to_tensor(phone_ids)
                 temp_phone_ids.append(phone_ids)
         if temp_tone_ids:
             result["tone_ids"] = temp_tone_ids
```