未验证 提交 ad239eb4 编写于 作者: 小湉湉's avatar 小湉湉 提交者: GitHub

[TTS]add VITS inference (#2972)

上级 ff8c56b0
../../tts3/local/export2lite.sh
\ No newline at end of file
#!/bin/bash
train_output_path=$1
add_blank=$2
stage=0
stop_stage=0
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
python3 ${BIN_DIR}/inference.py \
--inference_dir=${train_output_path}/inference \
--am=vits_csmsc \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/pd_infer_out \
--phones_dict=dump/phone_id_map.txt \
--add-blank=${add_blank}
fi
\ No newline at end of file
#!/bin/bash
train_output_path=$1
add_blank=$2
stage=0
stop_stage=0
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
python3 ${BIN_DIR}/../lite_predict.py \
--inference_dir=${train_output_path}/pdlite \
--am=vits_csmsc \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/lite_infer_out \
--phones_dict=dump/phone_id_map.txt \
--add-blank=${add_blank}
fi
......@@ -35,3 +35,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# synthesize_e2e, vocoder is pwgan
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} ${add_blank}|| exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} ${add_blank}|| exit -1
fi
......@@ -19,15 +19,15 @@ def get_lite_predictor(model_dir: Optional[os.PathLike]=None,
return predictor
def get_lite_am_output(
input: str,
am_predictor,
am: str,
frontend: object,
lang: str='zh',
merge_sentences: bool=True,
speaker_dict: Optional[os.PathLike]=None,
spk_id: int=0, ):
def get_lite_am_output(input: str,
am_predictor,
am: str,
frontend: object,
lang: str='zh',
merge_sentences: bool=True,
speaker_dict: Optional[os.PathLike]=None,
spk_id: int=0,
add_blank: bool=False):
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
get_spk_id = False
......@@ -43,7 +43,8 @@ def get_lite_am_output(
text=input,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
lang=lang)
lang=lang,
add_blank=add_blank, )
if get_tone_ids:
tone_ids = frontend_dict['tone_ids']
......
......@@ -284,7 +284,8 @@ def run_frontend(frontend: object,
merge_sentences: bool=False,
get_tone_ids: bool=False,
lang: str='zh',
to_tensor: bool=True):
to_tensor: bool=True,
add_blank: bool=False):
outs = dict()
if lang == 'zh':
input_ids = {}
......@@ -300,7 +301,8 @@ def run_frontend(frontend: object,
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
to_tensor=to_tensor)
to_tensor=to_tensor,
add_blank=add_blank)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
......@@ -576,15 +578,15 @@ def get_predictor(
return predictor
def get_am_output(
input: str,
am_predictor: paddle.nn.Layer,
am: str,
frontend: object,
lang: str='zh',
merge_sentences: bool=True,
speaker_dict: Optional[os.PathLike]=None,
spk_id: int=0, ):
def get_am_output(input: str,
am_predictor: paddle.nn.Layer,
am: str,
frontend: object,
lang: str='zh',
merge_sentences: bool=True,
speaker_dict: Optional[os.PathLike]=None,
spk_id: int=0,
add_blank: bool=False):
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_input_names = am_predictor.get_input_names()
......@@ -601,7 +603,8 @@ def get_am_output(
text=input,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
lang=lang)
lang=lang,
add_blank=add_blank, )
if get_tone_ids:
tone_ids = frontend_dict['tone_ids']
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
import paddle
import soundfile as sf
from timer import timer
from paddlespeech.t2s.exps.syn_utils import get_am_output
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_predictor
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.utils import str2bool
def parse_args():
parser = argparse.ArgumentParser(
description="Paddle Infernce with acoustic model & vocoder.")
# acoustic model
parser.add_argument(
'--am',
type=str,
default='vits_csmsc',
choices=['vits_csmsc', 'vits_aishell3'],
help='Choose acoustic model type of tts task.')
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker_dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
'--spk_id',
type=int,
default=0,
help='spk id for multi speaker acoustic model')
# other
parser.add_argument(
'--lang',
type=str,
default='zh',
help='Choose model language. zh or en or mix')
parser.add_argument(
"--text",
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line")
parser.add_argument(
"--add-blank",
type=str2bool,
default=True,
help="whether to add blank between phones")
parser.add_argument(
"--inference_dir", type=str, help="dir to save inference models")
parser.add_argument("--output_dir", type=str, help="output dir")
# inference
parser.add_argument(
"--use_trt",
type=str2bool,
default=False,
help="whether to use TensorRT or not in GPU", )
parser.add_argument(
"--use_mkldnn",
type=str2bool,
default=False,
help="whether to use MKLDNN or not in CPU.", )
parser.add_argument(
"--precision",
type=str,
default='fp32',
choices=['fp32', 'fp16', 'bf16', 'int8'],
help="mode of running")
parser.add_argument(
"--device",
default="gpu",
choices=["gpu", "cpu"],
help="Device selected for inference.", )
parser.add_argument('--cpu_threads', type=int, default=1)
args, _ = parser.parse_known_args()
return args
# only inference for models trained with csmsc now
def main():
args = parse_args()
paddle.set_device(args.device)
# frontend
frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
# am_predictor
am_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + ".pdmodel",
params_file=args.am + ".pdiparams",
device=args.device,
use_trt=args.use_trt,
use_mkldnn=args.use_mkldnn,
cpu_threads=args.cpu_threads,
precision=args.precision)
# model: {model_name}_{dataset}
am_dataset = args.am[args.am.rindex('_') + 1:]
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(text_file=args.text, lang=args.lang)
merge_sentences = True
add_blank = args.add_blank
# vits's fs is 22050
fs = 22050
# warmup
for utt_id, sentence in sentences[:3]:
with timer() as t:
wav = get_am_output(
input=sentence,
am_predictor=am_predictor,
am=args.am,
frontend=frontend,
lang=args.lang,
merge_sentences=merge_sentences,
speaker_dict=args.speaker_dict,
spk_id=args.spk_id,
add_blank=add_blank)
speed = wav.size / t.elapse
rtf = fs / speed
print(
f"{utt_id}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
print("warm up done!")
N = 0
T = 0
for utt_id, sentence in sentences:
with timer() as t:
wav = get_am_output(
input=sentence,
am_predictor=am_predictor,
am=args.am,
frontend=frontend,
lang=args.lang,
merge_sentences=merge_sentences,
speaker_dict=args.speaker_dict,
spk_id=args.spk_id,
add_blank=add_blank)
N += wav.size
T += t.elapse
speed = wav.size / t.elapse
rtf = fs / speed
sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=fs)
print(
f"{utt_id}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
print(f"{utt_id} done!")
print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")
if __name__ == "__main__":
main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
import soundfile as sf
from timer import timer
from paddlespeech.t2s.exps.lite_syn_utils import get_lite_am_output
from paddlespeech.t2s.exps.lite_syn_utils import get_lite_predictor
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
def parse_args():
parser = argparse.ArgumentParser(
description="Paddle Infernce with acoustic model & vocoder.")
# acoustic model
parser.add_argument(
'--am',
type=str,
default='vits_csmsc',
choices=[
'vits_csmsc',
'vits_aishell3',
],
help='Choose acoustic model type of tts task.')
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker_dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
'--spk_id',
type=int,
default=0,
help='spk id for multi speaker acoustic model')
# other
parser.add_argument(
'--lang',
type=str,
default='zh',
help='Choose model language. zh or en or mix')
parser.add_argument(
"--text",
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line")
parser.add_argument(
"--add-blank",
type=str2bool,
default=True,
help="whether to add blank between phones")
parser.add_argument(
"--inference_dir", type=str, help="dir to save inference models")
parser.add_argument("--output_dir", type=str, help="output dir")
args, _ = parser.parse_known_args()
return args
# only inference for models trained with csmsc now
def main():
args = parse_args()
# frontend
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# am_predictor
am_predictor = get_lite_predictor(
model_dir=args.inference_dir, model_file=args.am + "_x86.nb")
# model: {model_name}_{dataset}
am_dataset = args.am[args.am.rindex('_') + 1:]
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(text_file=args.text, lang=args.lang)
merge_sentences = True
add_blank = args.add_blank
fs = 22050
# warmup
for utt_id, sentence in sentences[:3]:
with timer() as t:
wav = get_lite_am_output(
input=sentence,
am_predictor=am_predictor,
am=args.am,
frontend=frontend,
lang=args.lang,
merge_sentences=merge_sentences,
speaker_dict=args.speaker_dict,
spk_id=args.spk_id,
add_blank=add_blank)
speed = wav.size / t.elapse
rtf = fs / speed
print(
f"{utt_id}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
print("warm up done!")
N = 0
T = 0
for utt_id, sentence in sentences:
with timer() as t:
wav = get_lite_am_output(
input=sentence,
am_predictor=am_predictor,
am=args.am,
frontend=frontend,
lang=args.lang,
merge_sentences=merge_sentences,
speaker_dict=args.speaker_dict,
spk_id=args.spk_id,
add_blank=add_blank)
N += wav.size
T += t.elapse
speed = wav.size / t.elapse
rtf = fs / speed
sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=fs)
print(
f"{utt_id}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
print(f"{utt_id} done!")
print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")
if __name__ == "__main__":
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册