synthesize_e2e.py 11.1 KB
Newer Older
小湉湉's avatar
小湉湉 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
from pprint import pprint

import paddle
import soundfile as sf
import yaml
from timer import timer
from yacs.config import CfgNode

from paddlespeech.t2s.exps.syn_utils import am_to_static
from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_sentences_svs
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.exps.syn_utils import run_frontend
from paddlespeech.t2s.exps.syn_utils import voc_to_static
from paddlespeech.t2s.utils import str2bool

def evaluate(args):
    """Run end-to-end TTS synthesis for every line of ``args.text``.

    Builds the text frontend, acoustic model, and vocoder from the given
    configs/checkpoints, optionally exports them to static graphs, then
    synthesizes one ``<utt_id>.wav`` per input line into ``args.output_dir``
    and reports per-utterance and overall speed/RTF.

    Args:
        args: Namespace produced by ``parse_args()``.
    """
    # Load model configurations.
    with open(args.am_config) as f:
        am_config = CfgNode(yaml.safe_load(f))
    with open(args.voc_config) as f:
        voc_config = CfgNode(yaml.safe_load(f))

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(am_config)
    print(voc_config)

    # frontend
    frontend = get_frontend(
        lang=args.lang,
        phones_dict=args.phones_dict,
        tones_dict=args.tones_dict,
        pinyin_phone=args.pinyin_phone,
        use_rhy=args.use_rhy)
    print("frontend done!")

    # acoustic model
    # ``--am`` is "<model>_<dataset>", e.g. "fastspeech2_csmsc"; split on the
    # LAST underscore so model names containing '_' still parse correctly.
    am_name = args.am[:args.am.rindex('_')]
    am_dataset = args.am[args.am.rindex('_') + 1:]
    am_inference = get_am_inference(
        am=args.am,
        am_config=am_config,
        am_ckpt=args.am_ckpt,
        am_stat=args.am_stat,
        phones_dict=args.phones_dict,
        tones_dict=args.tones_dict,
        speaker_dict=args.speaker_dict,
        speech_stretchs=args.speech_stretchs, )
    print("acoustic model done!")

    # vocoder
    voc_inference = get_voc_inference(
        voc=args.voc,
        voc_config=voc_config,
        voc_ckpt=args.voc_ckpt,
        voc_stat=args.voc_stat)
    print("voc done!")

    # whether dygraph to static
    if args.inference_dir:
        print("convert am and voc to static model.")
        # acoustic model
        am_inference = am_to_static(
            am_inference=am_inference,
            am=args.am,
            inference_dir=args.inference_dir,
            speaker_dict=args.speaker_dict)
        # vocoder
        voc_inference = voc_to_static(
            voc_inference=voc_inference,
            voc=args.voc,
            inference_dir=args.inference_dir)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    merge_sentences = False
    # Avoid not stopping at the end of a sub sentence when tacotron2_ljspeech
    # dygraph to static graph, but still not stopping in the end
    # (NOTE by yuantian01 Feb 9 2022)
    if am_name == 'tacotron2':
        merge_sentences = True

    get_tone_ids = False
    if am_name == 'speedyspeech':
        get_tone_ids = True

    # N: total wav samples generated; T: total inference time in seconds.
    N = 0
    T = 0

    # sentences: [(utt_id, text), ...]; diffsinger (SVS) uses a richer
    # score format parsed by get_sentences_svs.
    if am_name == 'diffsinger':
        sentences = get_sentences_svs(text_file=args.text)
    else:
        sentences = get_sentences(text_file=args.text, lang=args.lang)
    pprint(f"inputs: {sentences}")

    for utt_id, sentence in sentences:
        with timer() as t:
            if am_name == "diffsinger":
                # Singing synthesis: the whole structured input goes to the
                # SVS frontend instead of the plain-text field.
                text = ""
                svs_input = sentence
            else:
                text = sentence
                svs_input = None

            # frontend
            frontend_dict = run_frontend(
                frontend=frontend,
                text=text,
                merge_sentences=merge_sentences,
                get_tone_ids=get_tone_ids,
                lang=args.lang,
                svs_input=svs_input)
            phone_ids = frontend_dict['phone_ids']

            with paddle.no_grad():
                flags = 0
                for i in range(len(phone_ids)):
                    # sub phone, split by `sp` or punctuation.
                    part_phone_ids = phone_ids[i]

                    # acoustic model
                    if am_name == 'fastspeech2':
                        if am_dataset in {"aishell3", "vctk", "mix", "canton"}:
                            # multi-speaker
                            spk_id = paddle.to_tensor([args.spk_id])
                            mel = am_inference(part_phone_ids, spk_id)
                        else:
                            # single-speaker
                            mel = am_inference(part_phone_ids)
                    elif am_name == 'speedyspeech':
                        part_tone_ids = frontend_dict['tone_ids'][i]
                        if am_dataset in {"aishell3", "vctk", "mix"}:
                            # multi-speaker
                            spk_id = paddle.to_tensor([args.spk_id])
                            mel = am_inference(part_phone_ids, part_tone_ids,
                                               spk_id)
                        else:
                            # single-speaker
                            mel = am_inference(part_phone_ids, part_tone_ids)
                    elif am_name == 'tacotron2':
                        mel = am_inference(part_phone_ids)
                    elif am_name == 'diffsinger':
                        part_note_ids = frontend_dict['note_ids'][i]
                        part_note_durs = frontend_dict['note_durs'][i]
                        part_is_slurs = frontend_dict['is_slurs'][i]
                        mel = am_inference(
                            text=part_phone_ids,
                            note=part_note_ids,
                            note_dur=part_note_durs,
                            is_slur=part_is_slurs, )

                    # vocoder: concatenate sub-sentence wavs into one utterance.
                    wav = voc_inference(mel)
                    if flags == 0:
                        wav_all = wav
                        flags = 1
                    else:
                        wav_all = paddle.concat([wav_all, wav])

        wav = wav_all.numpy()
        N += wav.size
        T += t.elapse

        # samples per second
        speed = wav.size / t.elapse
        # generate one second wav need `RTF` seconds
        rtf = am_config.fs / speed
        print(
            f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
        )

        sf.write(
            str(output_dir / (utt_id + ".wav")), wav, samplerate=am_config.fs)
        print(f"{utt_id} done!")

    print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }")
小湉湉's avatar
小湉湉 已提交
203 204


小湉湉's avatar
小湉湉 已提交
205
def parse_args():
    """Parse command-line arguments for end-to-end synthesis.

    Returns:
        argparse.Namespace: parsed acoustic-model, vocoder, and runtime options.
    """
    parser = argparse.ArgumentParser(
        description="Synthesize with acoustic model & vocoder")
    # acoustic model
    parser.add_argument(
        '--am',
        type=str,
        default='fastspeech2_csmsc',
        choices=[
            'speedyspeech_csmsc',
            'speedyspeech_aishell3',
            'fastspeech2_csmsc',
            'fastspeech2_ljspeech',
            'fastspeech2_aishell3',
            'fastspeech2_vctk',
            'tacotron2_csmsc',
            'tacotron2_ljspeech',
            'fastspeech2_mix',
            'fastspeech2_canton',
            'fastspeech2_male-zh',
            'fastspeech2_male-en',
            'fastspeech2_male-mix',
            'diffsinger_opencpop',
        ],
        help='Choose acoustic model type of tts task.')
    parser.add_argument(
        '--am_config', type=str, default=None, help='Config of acoustic model.')
    parser.add_argument(
        '--am_ckpt',
        type=str,
        default=None,
        help='Checkpoint file of acoustic model.')
    parser.add_argument(
        "--am_stat",
        type=str,
        default=None,
        help="mean and standard deviation used to normalize spectrogram when training acoustic model."
    )
    parser.add_argument(
        "--phones_dict", type=str, default=None, help="phone vocabulary file.")
    parser.add_argument(
        "--tones_dict", type=str, default=None, help="tone vocabulary file.")
    parser.add_argument(
        "--speaker_dict", type=str, default=None, help="speaker id map file.")
    parser.add_argument(
        '--spk_id',
        type=int,
        default=0,
        help='spk id for multi speaker acoustic model')
    # vocoder
    parser.add_argument(
        '--voc',
        type=str,
        default='pwgan_csmsc',
        choices=[
            'pwgan_csmsc',
            'pwgan_ljspeech',
            'pwgan_aishell3',
            'pwgan_vctk',
            'mb_melgan_csmsc',
            'style_melgan_csmsc',
            'hifigan_csmsc',
            'hifigan_ljspeech',
            'hifigan_aishell3',
            'hifigan_vctk',
            'wavernn_csmsc',
            'pwgan_male',
            'hifigan_male',
            'pwgan_opencpop',
            'hifigan_opencpop',
        ],
        help='Choose vocoder type of tts task.')
    parser.add_argument(
        '--voc_config', type=str, default=None, help='Config of voc.')
    parser.add_argument(
        '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
    parser.add_argument(
        "--voc_stat",
        type=str,
        default=None,
        help="mean and standard deviation used to normalize spectrogram when training voc."
    )
    # other
    parser.add_argument(
        '--lang',
        type=str,
        default='zh',
        choices=['zh', 'en', 'mix', 'canton', 'sing'],
        help='Choose model language. zh or en or mix')

    parser.add_argument(
        "--inference_dir",
        type=str,
        default=None,
        help="dir to save inference models")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
    parser.add_argument(
        "--text",
        type=str,
        help="text to synthesize, a 'utt_id sentence' pair per line.")
    parser.add_argument("--output_dir", type=str, help="output dir.")
    parser.add_argument(
        "--use_rhy",
        type=str2bool,
        default=False,
        help="run rhythm frontend or not")
    parser.add_argument(
        "--pinyin_phone",
        type=str,
        default=None,
        help="pinyin to phone map file, using on sing_frontend.")
    parser.add_argument(
        "--speech_stretchs",
        type=str,
        default=None,
        help="The min and max values of the mel spectrum, using on diffusion of diffsinger."
    )

    args = parser.parse_args()
    return args


def main():
    """Script entry point: pick the device from ``--ngpu``, then synthesize."""
    args = parse_args()

    if args.ngpu == 0:
        paddle.set_device("cpu")
    elif args.ngpu > 0:
        paddle.set_device("gpu")
    else:
        # Fail fast on an invalid device count instead of silently
        # continuing with whatever device paddle currently has.
        raise ValueError("ngpu should >= 0 !")

    evaluate(args)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()