提交 1a9e5961 编写于 作者: 小湉湉's avatar 小湉湉

fix fastspeech2 multi speaker to static, test=tts

上级 4a133619
......@@ -257,6 +257,7 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
--output_dir=exp/default/test_e2e \
--phones_dict=fastspeech2_nosil_aishell3_ckpt_0.4/phone_id_map.txt \
--speaker_dict=fastspeech2_nosil_aishell3_ckpt_0.4/speaker_id_map.txt \
--spk_id=0
--spk_id=0 \
--inference_dir=exp/default/inference
```
......@@ -20,4 +20,5 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \
--speaker_dict=dump/speaker_id_map.txt \
--spk_id=0
--spk_id=0 \
--inference_dir=${train_output_path}/inference
......@@ -240,13 +240,14 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
--am_ckpt=fastspeech2_nosil_vctk_ckpt_0.5/snapshot_iter_66200.pdz \
--am_stat=fastspeech2_nosil_vctk_ckpt_0.5/speech_stats.npy \
--voc=pwgan_vctk \
--voc_config=pwg_vctk_ckpt_0.5/pwg_default.yaml \
--voc_ckpt=pwg_vctk_ckpt_0.5/pwg_snapshot_iter_1000000.pdz \
--voc_stat=pwg_vctk_ckpt_0.5/pwg_stats.npy \
--voc_config=pwg_vctk_ckpt_0.1.1/default.yaml \
--voc_ckpt=pwg_vctk_ckpt_0.1.1/snapshot_iter_1500000.pdz \
--voc_stat=pwg_vctk_ckpt_0.1.1/feats_stats.npy \
--lang=en \
--text=${BIN_DIR}/../sentences_en.txt \
--output_dir=exp/default/test_e2e \
--phones_dict=fastspeech2_nosil_vctk_ckpt_0.5/phone_id_map.txt \
--speaker_dict=fastspeech2_nosil_vctk_ckpt_0.5/speaker_id_map.txt \
--spk_id=0
--spk_id=0 \
--inference_dir=exp/default/inference
```
......@@ -20,4 +20,5 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \
--speaker_dict=dump/speaker_id_map.txt \
--spk_id=0
--spk_id=0 \
--inference_dir=${train_output_path}/inference
......@@ -14,9 +14,11 @@
import argparse
from pathlib import Path
import numpy
import soundfile as sf
from paddle import inference
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
......@@ -29,20 +31,38 @@ def main():
'--am',
type=str,
default='fastspeech2_csmsc',
choices=['speedyspeech_csmsc', 'fastspeech2_csmsc'],
choices=[
'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_aishell3',
'fastspeech2_vctk'
],
help='Choose acoustic model type of tts task.')
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--tones_dict", type=str, default=None, help="tone vocabulary file.")
parser.add_argument(
"--speaker_dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
'--spk_id',
type=int,
default=0,
help='spk id for multi speaker acoustic model')
# voc
parser.add_argument(
'--voc',
type=str,
default='pwgan_csmsc',
choices=['pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc'],
choices=[
'pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc', 'pwgan_aishell3',
'pwgan_vctk'
],
help='Choose vocoder type of tts task.')
# other
parser.add_argument(
'--lang',
type=str,
default='zh',
help='Choose model language. zh or en')
parser.add_argument(
"--text",
type=str,
......@@ -53,8 +73,12 @@ def main():
args, _ = parser.parse_known_args()
frontend = Frontend(
phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
# frontend
if args.lang == 'zh':
frontend = Frontend(
phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
elif args.lang == 'en':
frontend = English(phone_vocab_path=args.phones_dict)
print("frontend done!")
# model: {model_name}_{dataset}
......@@ -83,30 +107,52 @@ def main():
print("in new inference")
# construct dataset for evaluation
sentences = []
with open(args.text, 'rt') as f:
for line in f:
items = line.strip().split()
utt_id = items[0]
sentence = "".join(items[1:])
if args.lang == 'zh':
sentence = "".join(items[1:])
elif args.lang == 'en':
sentence = " ".join(items[1:])
sentences.append((utt_id, sentence))
get_tone_ids = False
if am_name == 'speedyspeech':
get_tone_ids = True
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
get_spk_id = True
spk_id = numpy.array([args.spk_id])
am_input_names = am_predictor.get_input_names()
print("am_input_names:", am_input_names)
merge_sentences = True
for utt_id, sentence in sentences:
input_ids = frontend.get_input_ids(
sentence, merge_sentences=True, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
if args.lang == 'zh':
input_ids = frontend.get_input_ids(
sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
elif args.lang == 'en':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
tones = tone_ids[0].numpy()
tones_handle = am_predictor.get_input_handle(am_input_names[1])
tones_handle.reshape(tones.shape)
tones_handle.copy_from_cpu(tones)
if get_spk_id:
spk_id_handle = am_predictor.get_input_handle(am_input_names[1])
spk_id_handle.reshape(spk_id.shape)
spk_id_handle.copy_from_cpu(spk_id)
phones = phone_ids[0].numpy()
phones_handle = am_predictor.get_input_handle(am_input_names[0])
phones_handle.reshape(phones.shape)
......
......@@ -159,9 +159,16 @@ def evaluate(args):
# acoustic model
if am_name == 'fastspeech2':
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
print(
"Haven't test dygraph to static for multi speaker fastspeech2 now!"
)
am_inference = jit.to_static(
am_inference,
input_spec=[
InputSpec([-1], dtype=paddle.int64),
InputSpec([1], dtype=paddle.int64)
])
paddle.jit.save(am_inference,
os.path.join(args.inference_dir, args.am))
am_inference = paddle.jit.load(
os.path.join(args.inference_dir, args.am))
else:
am_inference = jit.to_static(
am_inference,
......
......@@ -781,7 +781,7 @@ class FastSpeech2(nn.Layer):
elif self.spk_embed_integration_type == "concat":
# concat hidden states with spk embeds and then apply projection
spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
shape=[-1, hs.shape[1], -1])
shape=[-1, paddle.shape(hs)[1], -1])
hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1))
else:
raise NotImplementedError("support only add or concat.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册