synthesis.py 5.8 KB
Newer Older
L
lifuchen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
L
lifuchen 已提交
14
import os
走神的阿圆's avatar
走神的阿圆 已提交
15
from visualdl import LogWriter
16
from scipy.io.wavfile import write
L
lifuchen 已提交
17
from collections import OrderedDict
L
lifuchen 已提交
18
import argparse
L
lifuchen 已提交
19
from pprint import pprint
L
lifuchen 已提交
20
from ruamel import yaml
L
lifuchen 已提交
21
from matplotlib import cm
L
lifuchen 已提交
22 23 24 25 26
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
from parakeet.g2p.en import text_to_sequence
from parakeet import audio
L
lifuchen 已提交
27
from parakeet.models.fastspeech.fastspeech import FastSpeech
28
from parakeet.models.transformer_tts.utils import *
29 30
from parakeet.models.wavenet import WaveNet, UpsampleNet
from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
31 32
from parakeet.modules import weight_norm
from parakeet.models.waveflow import WaveFlowModule
33
from parakeet.utils.layer_tools import freeze
34
from parakeet.utils import io
L
lifuchen 已提交
35

L
lifuchen 已提交
36

37 38 39
def add_config_options_to_parser(parser):
    """Register the command-line options of the synthesis script on `parser`.

    Options (registered in this order):
        --config              path of the FastSpeech YAML config file.
        --vocoder             'griffin-lim' (default) or 'waveflow'.
        --config_vocoder      path of the vocoder YAML config file.
        --use_gpu             non-zero selects the GPU (default 0).
        --alpha               mel length expansion factor, i.e. voice speed
                              (default 1).
        --checkpoint          FastSpeech checkpoint path.
        --checkpoint_vocoder  vocoder checkpoint path.
        --output              output directory (default "synthesis").

    Args:
        parser (argparse.ArgumentParser): parser that is extended in place.
    """
    option_specs = [
        ("--config", dict(type=str, help="path of the config file")),
        ("--vocoder",
         dict(
             type=str,
             default="griffin-lim",
             choices=['griffin-lim', 'waveflow'],
             help="vocoder method")),
        ("--config_vocoder",
         dict(type=str, help="path of the vocoder config file")),
        ("--use_gpu", dict(type=int, default=0, help="device to use")),
        ("--alpha",
         dict(
             type=float,
             default=1,
             help="determine the length of the expanded sequence mel, "
             "controlling the voice speed.")),
        ("--checkpoint",
         dict(type=str, help="fastspeech checkpoint for synthesis")),
        ("--checkpoint_vocoder",
         dict(type=str, help="vocoder checkpoint for synthesis")),
        ("--output",
         dict(
             type=str,
             default="synthesis",
             help="path to save experiment results")),
    ]
    # Registration order matters for the generated --help text.
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
L
lifuchen 已提交
67

L
lifuchen 已提交
68

69 70 71
def synthesis(text_input, args):
    """Synthesize `text_input` with FastSpeech and the vocoder in `args.vocoder`.

    The FastSpeech model is built from the YAML config at `args.config` and
    restored from `args.checkpoint`. Its post-net mel output is turned into
    a waveform by Griffin-Lim or WaveFlow. The waveform is logged to a
    VisualDL `LogWriter` under `<args.output>/log` and written to
    `<args.output>/samples/<args.vocoder>.wav`.

    Args:
        text_input (str): text to synthesize; converted to ids with
            `parakeet.g2p.en.text_to_sequence`.
        args (argparse.Namespace): options from the parser configured by
            `add_config_options_to_parser`.

    Raises:
        ValueError: if `args.vocoder` is neither 'griffin-lim' nor 'waveflow'.
    """
    # Fail fast, before creating directories or loading models. Previously an
    # unknown vocoder only printed a message and then crashed with an
    # UnboundLocalError on `wav`.
    if args.vocoder not in ('griffin-lim', 'waveflow'):
        raise ValueError(
            'vocoder error, we only support griffinlim and waveflow, but received %s.'
            % args.vocoder)

    local_rank = dg.parallel.Env().local_rank
    place = (fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace())
    fluid.enable_dygraph(place)

    with open(args.config) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only use config files from trusted sources.
        cfg = yaml.load(f, Loader=yaml.Loader)

    # makedirs also creates args.output and tolerates existing directories.
    samples_dir = os.path.join(args.output, 'samples')
    os.makedirs(samples_dir, exist_ok=True)

    # VisualDL log directory.
    writer = LogWriter(os.path.join(args.output, 'log'))
    try:
        model = FastSpeech(cfg['network'], num_mels=cfg['audio']['num_mels'])
        # Load parameters.
        io.load_parameters(model=model, checkpoint_path=args.checkpoint)
        model.eval()

        # Batch of one: token ids and their 1-based positions.
        text = np.asarray(text_to_sequence(text_input))
        text = np.expand_dims(text, axis=0)
        pos_text = np.arange(1, text.shape[1] + 1)
        pos_text = np.expand_dims(pos_text, axis=0)

        text = dg.to_variable(text).astype(np.int64)
        pos_text = dg.to_variable(pos_text).astype(np.int64)

        _, mel_output_postnet = model(text, pos_text, alpha=args.alpha)

        if args.vocoder == 'griffin-lim':
            wav = synthesis_with_griffinlim(mel_output_postnet, cfg['audio'])
        else:  # 'waveflow' -- the only other value accepted above.
            wav = synthesis_with_waveflow(mel_output_postnet, args,
                                          args.checkpoint_vocoder, place)

        writer.add_audio(text_input + '(' + args.vocoder + ')', wav, 0,
                         cfg['audio']['sr'])
        write(
            os.path.join(samples_dir, args.vocoder + '.wav'),
            cfg['audio']['sr'], wav)
        print("Synthesis completed !!!")
    finally:
        # Close the log writer even if loading or synthesis fails.
        writer.close()

L
lifuchen 已提交
121

122
def synthesis_with_griffinlim(mel_output, cfg):
    """Convert a predicted mel spectrogram to a waveform with Griffin-Lim.

    Args:
        mel_output: FastSpeech mel output (Paddle variable). Axis 0 is
            squeezed and the remaining two axes are swapped. Presumably shaped
            (1, frames, num_mels) -- TODO confirm.
        cfg (dict): the `audio` config section. Reads `sr`, `n_fft`,
            `num_mels`, `fmin`, `fmax`, `power`, `hop_length` and
            `win_length`.

    Returns:
        numpy.ndarray: waveform returned by `librosa.core.griffinlim`.
    """
    # Explicit import: librosa was previously only reachable through the star
    # import from parakeet.models.transformer_tts.utils.
    import librosa

    mel_output = fluid.layers.transpose(
        fluid.layers.squeeze(mel_output, [0]), [1, 0])
    # np.exp undoes a natural-log compression of the mel output
    # (presumably applied during feature extraction -- confirm).
    mel_output = np.exp(mel_output.numpy())

    # Keyword arguments: librosa >= 0.10 made these parameters keyword-only,
    # and older versions accept the same keyword names.
    basis = librosa.filters.mel(
        sr=cfg['sr'],
        n_fft=cfg['n_fft'],
        n_mels=cfg['num_mels'],
        fmin=cfg['fmin'],
        fmax=cfg['fmax'])
    # Pseudo-inverse of the mel filterbank maps mel bins back to linear bins.
    inv_basis = np.linalg.pinv(basis)
    # The pseudo-inverse can yield negatives; clamp to a small positive floor
    # before the power below is applied.
    spec = np.maximum(1e-10, np.dot(inv_basis, mel_output))

    wav = librosa.core.griffinlim(
        spec**cfg['power'],
        hop_length=cfg['hop_length'],
        win_length=cfg['win_length'])
    return wav


def synthesis_with_waveflow(mel_output, args, checkpoint, place):
    """Convert a predicted mel spectrogram to a waveform with WaveFlow.

    Args:
        mel_output: FastSpeech mel output (Paddle variable). Axes 1 and 2 are
            swapped before inference. Presumably shaped
            (batch, frames, num_mels) -- TODO confirm.
        args: parsed command-line options. `args.config_vocoder` names the
            WaveFlow YAML config. The caller's `args` object is not modified.
        checkpoint (str): WaveFlow checkpoint path.
        place: Paddle place (CPU/CUDA) used to enable dygraph mode.

    Returns:
        numpy.ndarray: element 0 along the first (presumably batch) axis of
        the synthesized waveform.
    """
    import copy

    fluid.enable_dygraph(place)
    # Work on a shallow copy. The original code overwrote the caller's
    # args.config with the vocoder config and injected args.use_fp16.
    vocoder_args = copy.copy(args)
    vocoder_args.config = args.config_vocoder
    vocoder_args.use_fp16 = False
    config = io.add_yaml_config_to_args(vocoder_args)

    mel_spectrogram = fluid.layers.transpose(mel_output, [0, 2, 1])

    # Build model.
    waveflow = WaveFlowModule(config)
    io.load_parameters(model=waveflow, checkpoint_path=checkpoint)
    # Strip the weight-norm wrappers from every sublayer before inference.
    for layer in waveflow.sublayers():
        if isinstance(layer, weight_norm.WeightNormWrapper):
            layer.remove_weight_norm()

    # Run model inference.
    wav = waveflow.synthesize(mel_spectrogram, sigma=config.sigma)
    return wav.numpy()[0]
161 162


L
lifuchen 已提交
163
if __name__ == '__main__':
    # Script entry point: parse CLI options, echo them, synthesize a fixed
    # demo sentence.
    arg_parser = argparse.ArgumentParser(description="Synthesis model")
    add_config_options_to_parser(arg_parser)
    cli_args = arg_parser.parse_args()
    pprint(vars(cli_args))
    demo_sentence = ("Don't argue with the people of strong determination, "
                     "because they may change the fact!")
    synthesis(demo_sentence, cli_args)