# fastspeech.py
import math
import paddle.fluid.dygraph as dg
import paddle.fluid as fluid
from parakeet.g2p.text.symbols import symbols
from parakeet.models.transformer_tts.post_convnet import PostConvNet
from parakeet.models.fastspeech.length_regulator import LengthRegulator
from parakeet.models.fastspeech.encoder import Encoder
from parakeet.models.fastspeech.decoder import Decoder

class FastSpeech(dg.Layer):
    def __init__(self, cfg):
        """FastSpeech: a feed-forward (non-autoregressive) text-to-mel-spectrogram model.

        Args:
            cfg (dict): the model configuration (hidden sizes, layer counts,
                FFT-block settings and audio parameters).
        """
        super(FastSpeech, self).__init__()

        self.encoder = Encoder(n_src_vocab=len(symbols)+1,
                               len_max_seq=cfg['max_seq_len'],
                               n_layers=cfg['encoder_n_layer'],
                               n_head=cfg['encoder_head'],
                               d_k=cfg['fs_hidden_size'] // cfg['encoder_head'],
                               d_v=cfg['fs_hidden_size'] // cfg['encoder_head'],
                               d_model=cfg['fs_hidden_size'],
                               d_inner=cfg['encoder_conv1d_filter_size'],
                               fft_conv1d_kernel=cfg['fft_conv1d_filter'], 
                               fft_conv1d_padding=cfg['fft_conv1d_padding'],
                               dropout=0.1)
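        # The length regulator holds the duration predictor and expands each
        # encoder output frame according to its duration (ground-truth durations
        # during training, predicted durations scaled by alpha at inference).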
        self.length_regulator = LengthRegulator(input_size=cfg['fs_hidden_size'], 
                                                out_channels=cfg['duration_predictor_output_size'], 
                                                filter_size=cfg['duration_predictor_filter_size'], 
                                                dropout=cfg['dropout'])
        self.decoder = Decoder(len_max_seq=cfg['max_seq_len'],
                               n_layers=cfg['decoder_n_layer'],
                               n_head=cfg['decoder_head'],
                               d_k=cfg['fs_hidden_size'] // cfg['decoder_head'],
                               d_v=cfg['fs_hidden_size'] // cfg['decoder_head'],
                               d_model=cfg['fs_hidden_size'],
                               d_inner=cfg['decoder_conv1d_filter_size'],
                               fft_conv1d_kernel=cfg['fft_conv1d_filter'],
                               fft_conv1d_padding=cfg['fft_conv1d_padding'],
                               dropout=0.1)
        # The mel projection uses Xavier-initialized weights and a uniform bias
        # with bound 1/sqrt(hidden_size), the usual fan-in initialization.
        self.weight = fluid.ParamAttr(initializer=fluid.initializer.XavierInitializer())
        k = math.sqrt(1.0 / cfg['fs_hidden_size'])
        self.bias = fluid.ParamAttr(initializer=fluid.initializer.Uniform(low=-k, high=k))
        self.mel_linear = dg.Linear(cfg['fs_hidden_size'],
                                    cfg['audio']['num_mels'] * cfg['audio']['outputs_per_step'],
                                    param_attr=self.weight,
                                    bias_attr=self.bias)
        self.postnet = PostConvNet(n_mels=cfg['audio']['num_mels'],
                                   num_hidden=512,
                                   filter_size=5,
                                   padding=int(5 / 2),
                                   num_conv=5,
                                   outputs_per_step=cfg['audio']['outputs_per_step'],
                                   use_cudnn=True,
                                   dropout=0.1,
                                   batchnorm_last=True)

    def forward(self, character, text_pos, mel_pos=None, length_target=None, alpha=1.0):
        """
        Compute mel spectrograms from text.

        Args:
            character (Variable): Shape(B, T_text), dtype: int64. The input text
                character ids. T_text means the timesteps of the input characters.
            text_pos (Variable): Shape(B, T_text), dtype: int64. The input text
                positions.
            mel_pos (Variable, optional): Shape(B, T_mel), dtype: int64. The spectrum
                positions. T_mel means the timesteps of the input spectrum. Only used
                in training.
            length_target (Variable, optional): Shape(B, T_text), dtype: int64. The
                phoneme durations computed from a pretrained TransformerTTS model.
                Only used in training.
            alpha (float, optional): The factor that scales the predicted durations,
                and hence the length of the expanded mel sequence, thereby controlling
                the speed of the synthesized speech. Defaults to 1.0.

        Returns:
            mel_output (Variable), Shape(B, T_mel, C), the mel output before the postnet.
            mel_output_postnet (Variable), Shape(B, T_mel, C), the mel output after the postnet.
            duration_predictor_output (Variable), Shape(B, T_text), the phoneme durations
                predicted by the duration predictor (training mode only).
            enc_slf_attn_list (list[Variable]), each of Shape(B, T_text, T_text), the encoder
                self-attention maps (training mode only).
            dec_slf_attn_list (list[Variable]), each of Shape(B, T_mel, T_mel), the decoder
                self-attention maps (training mode only).

            In evaluation mode only mel_output and mel_output_postnet are returned.
        """

        encoder_output, non_pad_mask, enc_slf_attn_list = self.encoder(character, text_pos)
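        # The dygraph tracer's (private) _train_mode flag selects the branch below:
        # training expands the encoder output with the ground-truth durations and
        # also returns the attention maps, while inference relies on the duration
        # predictor and returns only the mel outputs.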
        if fluid.framework._dygraph_tracer()._train_mode:
            length_regulator_output, duration_predictor_output = self.length_regulator(encoder_output,
                                                                                       target=length_target,
                                                                                       alpha=alpha)
            decoder_output, dec_slf_attn_list = self.decoder(length_regulator_output, mel_pos)

            mel_output = self.mel_linear(decoder_output)
            mel_output_postnet = self.postnet(mel_output) + mel_output

            return mel_output, mel_output_postnet, duration_predictor_output, enc_slf_attn_list, dec_slf_attn_list
        else:
            length_regulator_output, decoder_pos = self.length_regulator(encoder_output, alpha=alpha)
            decoder_output, _ = self.decoder(length_regulator_output, decoder_pos)
            mel_output = self.mel_linear(decoder_output)
            mel_output_postnet = self.postnet(mel_output) + mel_output

            return mel_output, mel_output_postnet
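

# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the original module). It assumes a cfg
# dict with the keys read in __init__ above; the values below are placeholders
# chosen only for illustration, so real runs should take them from the Parakeet
# fastspeech config file. It also assumes Layer.eval() switches the dygraph
# tracer out of train mode so the inference branch of forward() is taken.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    cfg = {
        'max_seq_len': 2048,
        'fs_hidden_size': 384,
        'encoder_n_layer': 6,
        'encoder_head': 2,
        'encoder_conv1d_filter_size': 1536,
        'decoder_n_layer': 6,
        'decoder_head': 2,
        'decoder_conv1d_filter_size': 1536,
        'fft_conv1d_filter': 3,
        'fft_conv1d_padding': 1,
        'duration_predictor_output_size': 256,
        'duration_predictor_filter_size': 3,
        'dropout': 0.1,
        'audio': {'num_mels': 80, 'outputs_per_step': 1},
    }  # placeholder values, assumed only for this sketch

    with dg.guard():
        model = FastSpeech(cfg)
        model.eval()

        # One utterance of 10 character ids (0 is assumed to be the padding id,
        # matching n_src_vocab = len(symbols) + 1) and their 1-based positions.
        text = np.random.randint(1, len(symbols) + 1, size=(1, 10)).astype('int64')
        pos = np.arange(1, 11, dtype='int64').reshape(1, 10)

        # Inference branch: no mel_pos / length_target needed; alpha controls speed.
        mel, mel_postnet = model(dg.to_variable(text), dg.to_variable(pos), alpha=1.0)
        print(mel_postnet.shape)  # (1, T_mel, num_mels * outputs_per_step)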