From 1a9e59612a9124ffc3f97aae57f4d24792cdb9cf Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Tue, 18 Jan 2022 03:53:27 +0000
Subject: [PATCH] fix fastspeech2 multi speaker to static, test=tts

---
 examples/aishell3/tts3/README.md              |  3 +-
 .../aishell3/tts3/local/synthesize_e2e.sh     |  3 +-
 examples/vctk/tts3/README.md                  |  9 +--
 examples/vctk/tts3/local/synthesize_e2e.sh    |  3 +-
 paddlespeech/t2s/exps/inference.py            | 66 ++++++++++++++++---
 paddlespeech/t2s/exps/synthesize_e2e.py       | 13 +++-
 .../t2s/models/fastspeech2/fastspeech2.py     |  2 +-
 7 files changed, 78 insertions(+), 21 deletions(-)

diff --git a/examples/aishell3/tts3/README.md b/examples/aishell3/tts3/README.md
index 2538e8f9..281ad836 100644
--- a/examples/aishell3/tts3/README.md
+++ b/examples/aishell3/tts3/README.md
@@ -257,6 +257,7 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
   --output_dir=exp/default/test_e2e \
   --phones_dict=fastspeech2_nosil_aishell3_ckpt_0.4/phone_id_map.txt \
   --speaker_dict=fastspeech2_nosil_aishell3_ckpt_0.4/speaker_id_map.txt \
-  --spk_id=0
+  --spk_id=0 \
+  --inference_dir=exp/default/inference
 ```
 
diff --git a/examples/aishell3/tts3/local/synthesize_e2e.sh b/examples/aishell3/tts3/local/synthesize_e2e.sh
index d0d92585..60e1a5ce 100755
--- a/examples/aishell3/tts3/local/synthesize_e2e.sh
+++ b/examples/aishell3/tts3/local/synthesize_e2e.sh
@@ -20,4 +20,5 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
         --output_dir=${train_output_path}/test_e2e \
         --phones_dict=dump/phone_id_map.txt \
         --speaker_dict=dump/speaker_id_map.txt \
-        --spk_id=0
+        --spk_id=0 \
+        --inference_dir=${train_output_path}/inference
diff --git a/examples/vctk/tts3/README.md b/examples/vctk/tts3/README.md
index 74c1086a..157949d1 100644
--- a/examples/vctk/tts3/README.md
+++ b/examples/vctk/tts3/README.md
@@ -240,13 +240,14 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
   --am_ckpt=fastspeech2_nosil_vctk_ckpt_0.5/snapshot_iter_66200.pdz \
   --am_stat=fastspeech2_nosil_vctk_ckpt_0.5/speech_stats.npy \
   --voc=pwgan_vctk \
-  --voc_config=pwg_vctk_ckpt_0.5/pwg_default.yaml \
-  --voc_ckpt=pwg_vctk_ckpt_0.5/pwg_snapshot_iter_1000000.pdz \
-  --voc_stat=pwg_vctk_ckpt_0.5/pwg_stats.npy \
+  --voc_config=pwg_vctk_ckpt_0.1.1/default.yaml \
+  --voc_ckpt=pwg_vctk_ckpt_0.1.1/snapshot_iter_1500000.pdz \
+  --voc_stat=pwg_vctk_ckpt_0.1.1/feats_stats.npy \
   --lang=en \
   --text=${BIN_DIR}/../sentences_en.txt \
   --output_dir=exp/default/test_e2e \
   --phones_dict=fastspeech2_nosil_vctk_ckpt_0.5/phone_id_map.txt \
   --speaker_dict=fastspeech2_nosil_vctk_ckpt_0.5/speaker_id_map.txt \
-  --spk_id=0
+  --spk_id=0 \
+  --inference_dir=exp/default/inference
 ```
diff --git a/examples/vctk/tts3/local/synthesize_e2e.sh b/examples/vctk/tts3/local/synthesize_e2e.sh
index 51bb9e19..60d56d1c 100755
--- a/examples/vctk/tts3/local/synthesize_e2e.sh
+++ b/examples/vctk/tts3/local/synthesize_e2e.sh
@@ -20,4 +20,5 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
         --output_dir=${train_output_path}/test_e2e \
         --phones_dict=dump/phone_id_map.txt \
         --speaker_dict=dump/speaker_id_map.txt \
-        --spk_id=0
+        --spk_id=0 \
+        --inference_dir=${train_output_path}/inference
diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py
index e1d5306c..2c9b51f9 100644
--- a/paddlespeech/t2s/exps/inference.py
+++ b/paddlespeech/t2s/exps/inference.py
@@ -14,9 +14,11 @@ import argparse
 from pathlib import Path
 
+import numpy
 import soundfile as sf
 from paddle import inference
 
+from paddlespeech.t2s.frontend import English
 from paddlespeech.t2s.frontend.zh_frontend import Frontend
 
 
@@ -29,20 +31,38 @@ def main():
         '--am',
         type=str,
         default='fastspeech2_csmsc',
-        choices=['speedyspeech_csmsc', 'fastspeech2_csmsc'],
+        choices=[
+            'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_aishell3',
+            'fastspeech2_vctk'
+        ],
         help='Choose acoustic model type of tts task.')
     parser.add_argument(
         "--phones_dict", type=str, default=None, help="phone vocabulary file.")
     parser.add_argument(
         "--tones_dict", type=str, default=None, help="tone vocabulary file.")
+    parser.add_argument(
+        "--speaker_dict", type=str, default=None, help="speaker id map file.")
+    parser.add_argument(
+        '--spk_id',
+        type=int,
+        default=0,
+        help='spk id for multi speaker acoustic model')
     # voc
     parser.add_argument(
         '--voc',
         type=str,
         default='pwgan_csmsc',
-        choices=['pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc'],
+        choices=[
+            'pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc', 'pwgan_aishell3',
+            'pwgan_vctk'
+        ],
         help='Choose vocoder type of tts task.')
     # other
+    parser.add_argument(
+        '--lang',
+        type=str,
+        default='zh',
+        help='Choose model language. zh or en')
     parser.add_argument(
         "--text",
         type=str,
@@ -53,8 +73,12 @@ def main():
 
     args, _ = parser.parse_known_args()
 
-    frontend = Frontend(
-        phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
+    # frontend
+    if args.lang == 'zh':
+        frontend = Frontend(
+            phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
+    elif args.lang == 'en':
+        frontend = English(phone_vocab_path=args.phones_dict)
     print("frontend done!")
 
     # model: {model_name}_{dataset}
@@ -83,30 +107,52 @@ def main():
     print("in new inference")
 
+    # construct dataset for evaluation
+    sentences = []
     with open(args.text, 'rt') as f:
         for line in f:
            items = line.strip().split()
            utt_id = items[0]
-           sentence = "".join(items[1:])
+           if args.lang == 'zh':
+               sentence = "".join(items[1:])
+           elif args.lang == 'en':
+               sentence = " ".join(items[1:])
            sentences.append((utt_id, sentence))
 
     get_tone_ids = False
     if am_name == 'speedyspeech':
         get_tone_ids = True
+    get_spk_id = am_dataset in {"aishell3", "vctk"} and bool(args.speaker_dict)
+    if get_spk_id:
+        spk_id = numpy.array([args.spk_id])
 
     am_input_names = am_predictor.get_input_names()
-
+    print("am_input_names:", am_input_names)
+    merge_sentences = True
     for utt_id, sentence in sentences:
-        input_ids = frontend.get_input_ids(
-            sentence, merge_sentences=True, get_tone_ids=get_tone_ids)
-        phone_ids = input_ids["phone_ids"]
+        if args.lang == 'zh':
+            input_ids = frontend.get_input_ids(
+                sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids)
+            phone_ids = input_ids["phone_ids"]
+        elif args.lang == 'en':
+            input_ids = frontend.get_input_ids(
+                sentence, merge_sentences=merge_sentences)
+            phone_ids = input_ids["phone_ids"]
+        else:
+            print("lang should be in {'zh', 'en'}!")
+
         if get_tone_ids:
             tone_ids = input_ids["tone_ids"]
             tones = tone_ids[0].numpy()
             tones_handle = am_predictor.get_input_handle(am_input_names[1])
             tones_handle.reshape(tones.shape)
             tones_handle.copy_from_cpu(tones)
-
+        if get_spk_id:
+            spk_id_handle = am_predictor.get_input_handle(am_input_names[1])
+            spk_id_handle.reshape(spk_id.shape)
+            spk_id_handle.copy_from_cpu(spk_id)
         phones = phone_ids[0].numpy()
         phones_handle = am_predictor.get_input_handle(am_input_names[0])
         phones_handle.reshape(phones.shape)
         phones_handle.copy_from_cpu(phones)
diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py
index 15ed1e4d..9b503213 100644
--- a/paddlespeech/t2s/exps/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/synthesize_e2e.py
@@ -159,9 +159,16 @@ def evaluate(args):
     # acoustic model
     if am_name == 'fastspeech2':
         if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
-            print(
-                "Haven't test dygraph to static for multi speaker fastspeech2 now!"
-            )
+            am_inference = jit.to_static(
+                am_inference,
+                input_spec=[
+                    InputSpec([-1], dtype=paddle.int64),
+                    InputSpec([1], dtype=paddle.int64)
+                ])
+            paddle.jit.save(am_inference,
+                            os.path.join(args.inference_dir, args.am))
+            am_inference = paddle.jit.load(
+                os.path.join(args.inference_dir, args.am))
         else:
             am_inference = jit.to_static(
                 am_inference,
diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
index 6bb651a0..dc136ffd 100644
--- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
+++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
@@ -781,7 +781,7 @@ class FastSpeech2(nn.Layer):
         elif self.spk_embed_integration_type == "concat":
             # concat hidden states with spk embeds and then apply projection
             spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
-                shape=[-1, hs.shape[1], -1])
+                shape=[-1, paddle.shape(hs)[1], -1])
             hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1))
         else:
             raise NotImplementedError("support only add or concat.")
-- 
GitLab
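
The three Python hunks above implement one pipeline: synthesize_e2e.py now exports the multi-speaker FastSpeech2 with a second `InputSpec` for the speaker id, inference.py feeds that id through an extra input handle of the Paddle Inference predictor, and the fastspeech2.py hunk swaps `hs.shape[1]` for `paddle.shape(hs)[1]` because under `jit.to_static` the former is frozen at trace time (dynamic axes become -1, which breaks `expand`), while the latter is evaluated at run time. The sketch below condenses the export and run steps into two standalone helpers. It is a minimal illustration under stated assumptions, not the project's CLI: the function names are invented here, `am_inference` is assumed to be a FastSpeech2 inference layer with weights already restored, and only the acoustic model (no vocoder) is covered.

```python
import os

import numpy
import paddle
from paddle import inference, jit
from paddle.static import InputSpec


def export_multi_speaker_am(am_inference, inference_dir, am_name):
    """Trace a multi-speaker acoustic model to a static graph.

    Mirrors the synthesize_e2e.py hunk: the second InputSpec carries the
    speaker id, so the traced program takes (phone_ids, spk_id).
    `am_inference` is assumed to be a FastSpeech2 inference layer with
    weights already loaded (not shown here).
    """
    static_am = jit.to_static(
        am_inference,
        input_spec=[
            InputSpec([-1], dtype=paddle.int64),  # phone ids, variable length
            InputSpec([1], dtype=paddle.int64),  # speaker id
        ])
    paddle.jit.save(static_am, os.path.join(inference_dir, am_name))


def run_exported_am(inference_dir, am_name, phones, spk_id=0):
    """Run the exported program with an explicit speaker id.

    Mirrors the inference.py hunk: input handle 0 takes the phone id
    sequence, handle 1 the speaker id; the first output is the mel.
    """
    config = inference.Config(
        os.path.join(inference_dir, am_name + ".pdmodel"),
        os.path.join(inference_dir, am_name + ".pdiparams"))
    predictor = inference.create_predictor(config)
    input_names = predictor.get_input_names()

    phones = numpy.asarray(phones, dtype=numpy.int64)
    phones_handle = predictor.get_input_handle(input_names[0])
    phones_handle.reshape(phones.shape)
    phones_handle.copy_from_cpu(phones)

    spk = numpy.array([spk_id], dtype=numpy.int64)
    spk_handle = predictor.get_input_handle(input_names[1])
    spk_handle.reshape(spk.shape)
    spk_handle.copy_from_cpu(spk)

    predictor.run()
    output_names = predictor.get_output_names()
    return predictor.get_output_handle(output_names[0]).copy_to_cpu()
```

Saving and immediately `paddle.jit.load`-ing the model, as the patch does in synthesize_e2e.py, doubles as a smoke test: the exported program is exercised right away during synthesis, so a broken trace surfaces before the predictor path in inference.py is ever reached.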