fastspeech.py 7.2 KB
Newer Older
L
lifuchen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
L
lifuchen 已提交
14
import math
15
import numpy as np
L
lifuchen 已提交
16 17 18
import paddle.fluid.dygraph as dg
import paddle.fluid as fluid
from parakeet.g2p.text.symbols import symbols
19
from parakeet.models.transformer_tts.utils import *
L
lifuchen 已提交
20
from parakeet.models.transformer_tts.post_convnet import PostConvNet
L
lifuchen 已提交
21
from parakeet.models.fastspeech.length_regulator import LengthRegulator
L
lifuchen 已提交
22 23
from parakeet.models.fastspeech.encoder import Encoder
from parakeet.models.fastspeech.decoder import Decoder
L
lifuchen 已提交
24

L
lifuchen 已提交
25

L
lifuchen 已提交
26 27
class FastSpeech(dg.Layer):
    """FastSpeech text-to-mel model.

    The layer chains an encoder, a length regulator, a decoder, a linear
    projection to mel frames and a convolutional post-net.
    """

    def __init__(self, cfg):
        """FastSpeech model.

        Args:
            cfg: the yaml configs used in FastSpeech model.
        """
        super(FastSpeech, self).__init__()

        hidden_size = cfg['fs_hidden_size']
        audio_cfg = cfg['audio']

        # NOTE(review): encoder/decoder dropout is fixed at 0.1 here and does
        # not read cfg['dropout'] (only the length regulator does).
        self.encoder = Encoder(
            n_src_vocab=len(symbols) + 1,
            len_max_seq=cfg['max_seq_len'],
            n_layers=cfg['encoder_n_layer'],
            n_head=cfg['encoder_head'],
            d_k=hidden_size // cfg['encoder_head'],
            d_q=hidden_size // cfg['encoder_head'],
            d_model=hidden_size,
            d_inner=cfg['encoder_conv1d_filter_size'],
            fft_conv1d_kernel=cfg['fft_conv1d_filter'],
            fft_conv1d_padding=cfg['fft_conv1d_padding'],
            dropout=0.1)
        self.length_regulator = LengthRegulator(
            input_size=hidden_size,
            out_channels=cfg['duration_predictor_output_size'],
            filter_size=cfg['duration_predictor_filter_size'],
            dropout=cfg['dropout'])
        self.decoder = Decoder(
            len_max_seq=cfg['max_seq_len'],
            n_layers=cfg['decoder_n_layer'],
            n_head=cfg['decoder_head'],
            d_k=hidden_size // cfg['decoder_head'],
            d_q=hidden_size // cfg['decoder_head'],
            d_model=hidden_size,
            d_inner=cfg['decoder_conv1d_filter_size'],
            fft_conv1d_kernel=cfg['fft_conv1d_filter'],
            fft_conv1d_padding=cfg['fft_conv1d_padding'],
            dropout=0.1)

        # Parameter attributes for the mel projection: Xavier-initialised
        # weight, bias drawn uniformly from [-k, k] with k = sqrt(1 / hidden).
        self.weight = fluid.ParamAttr(
            initializer=fluid.initializer.XavierInitializer())
        k = math.sqrt(1 / hidden_size)
        self.bias = fluid.ParamAttr(initializer=fluid.initializer.Uniform(
            low=-k, high=k))
        self.mel_linear = dg.Linear(
            hidden_size,
            audio_cfg['num_mels'] * audio_cfg['outputs_per_step'],
            param_attr=self.weight,
            bias_attr=self.bias, )

        postnet_filter_size = 5
        self.postnet = PostConvNet(
            n_mels=audio_cfg['num_mels'],
            num_hidden=512,
            filter_size=postnet_filter_size,
            padding=postnet_filter_size // 2,
            num_conv=5,
            outputs_per_step=audio_cfg['outputs_per_step'],
            use_cudnn=True,
            dropout=0.1,
            batchnorm_last=True)

    def forward(self,
                character,
                text_pos,
                enc_non_pad_mask,
                dec_non_pad_mask,
                mel_pos=None,
                enc_slf_attn_mask=None,
                dec_slf_attn_mask=None,
                length_target=None,
                alpha=1.0):
        """
        Compute mel output from text character.

        The behaviour depends on the dygraph tracer's train mode: in training
        the decoder consumes the masks/positions supplied by the caller, in
        inference they are rebuilt from the length regulator's output.

        Args:
            character (Variable): shape(B, T_text), dtype float32, the input text characters,
                where T_text means the timesteps of input characters.
            text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
            enc_non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
            dec_non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
                Only used in training; in inference it is recomputed from the
                predicted decoder positions and the passed value is ignored.
            mel_pos (Variable, optional): shape(B, T_mel), dtype int64, the spectrum position,
                where T_mel means the timesteps of input spectrum. Only used in
                training. Defaults to None.
            enc_slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
                the mask of input characters. Defaults to None.
            dec_slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
                the mask of mel spectrum. Only used in training; in inference
                the mask is rebuilt internally. Defaults to None.
            length_target (Variable, optional): shape(B, T_text), dtype int64,
                the duration of phoneme compute from pretrained transformerTTS.
                Only used in training. Defaults to None.
            alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence
                mel, thereby controlling the voice speed. Defaults to 1.0.

        Returns:
            In training mode, a 5-tuple:
                mel_output (Variable): shape(B, T_mel, C), the mel output before postnet.
                mel_output_postnet (Variable): shape(B, T_mel, C), the mel output after postnet.
                duration_predictor_output (Variable): shape(B, T_text), the duration of phoneme compute with duration predictor.
                enc_slf_attn_list (List[Variable]): len(enc_n_layers), the encoder self attention list.
                dec_slf_attn_list (List[Variable]): len(dec_n_layers), the decoder self attention list.
            In inference mode, a 2-tuple:
                mel_output (Variable): the mel output before postnet.
                mel_output_postnet (Variable): the mel output after postnet.
        """
        encoder_output, enc_slf_attn_list = self.encoder(
            character,
            text_pos,
            enc_non_pad_mask,
            slf_attn_mask=enc_slf_attn_mask)

        # The train/eval switch is read from the global dygraph tracer.
        if fluid.framework._dygraph_tracer()._train_mode:
            length_regulator_output, duration_predictor_output = self.length_regulator(
                encoder_output, target=length_target, alpha=alpha)
            decoder_output, dec_slf_attn_list = self.decoder(
                length_regulator_output,
                mel_pos,
                dec_non_pad_mask,
                slf_attn_mask=dec_slf_attn_mask)

            mel_output = self.mel_linear(decoder_output)
            # Post-net output is added back onto the linear projection.
            mel_output_postnet = self.postnet(mel_output) + mel_output

            return (mel_output, mel_output_postnet, duration_predictor_output,
                    enc_slf_attn_list, dec_slf_attn_list)

        # Inference: the length regulator also returns the decoder positions,
        # from which the decoder masks are derived.
        length_regulator_output, decoder_pos = self.length_regulator(
            encoder_output, alpha=alpha)

        # Convert to numpy once and reuse it for both helper arguments.
        decoder_pos_np = decoder_pos.numpy()
        triu_mask = get_triu_tensor(decoder_pos_np,
                                    decoder_pos_np).astype(np.float32)
        # Entries equal to 0 in the helper's output become 1.0, all others 0.0.
        slf_attn_mask = fluid.layers.cast(
            dg.to_variable(triu_mask == 0), np.float32)

        # Non-zero positions are treated as real (non-padding) frames.
        dec_non_pad_mask = fluid.layers.unsqueeze(
            (decoder_pos != 0).astype(np.float32), [-1])
        decoder_output, _ = self.decoder(
            length_regulator_output,
            decoder_pos,
            dec_non_pad_mask,
            slf_attn_mask=slf_attn_mask)

        mel_output = self.mel_linear(decoder_output)
        mel_output_postnet = self.postnet(mel_output) + mel_output

        return mel_output, mel_output_postnet