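"""Decoder of the Transformer-TTS model, in Paddle fluid dygraph.

Combines a decoder pre-net, a learnable-scaled sinusoidal position
encoding, three blocks of self-attention / encoder-decoder attention /
position-wise FFN, a mel linear projection with a convolutional post-net,
and a stop-token predictor.
"""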
import math
import paddle.fluid.dygraph as dg
import paddle.fluid as fluid
import paddle.fluid.layers as layers

from parakeet.models.transformer_tts.utils import *
from parakeet.modules.multihead_attention import MultiheadAttention
from parakeet.modules.ffn import PositionwiseFeedForward
from parakeet.models.transformer_tts.prenet import PreNet
from parakeet.models.transformer_tts.post_convnet import PostConvNet


class Decoder(dg.Layer):
    def __init__(self, num_hidden, config, num_head=4):
        super(Decoder, self).__init__()
        self.num_hidden = num_hidden
        self.num_head = num_head
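        # `alpha` is a trainable scalar that rescales the positional encoding
        # before it is added to the pre-net output in forward(); the sinusoid
        # table itself is frozen (trainable=False).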
        param = fluid.ParamAttr()
        self.alpha = self.create_parameter(
            shape=(1, ),
            attr=param,
            dtype='float32',
            default_initializer=fluid.initializer.ConstantInitializer(value=1.0))
        self.pos_inp = get_sinusoid_encoding_table(1024, self.num_hidden, padding_idx=0)
        self.pos_emb = dg.Embedding(
            size=[1024, num_hidden],
            padding_idx=0,
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.NumpyArrayInitializer(self.pos_inp),
                trainable=False))
        self.decoder_prenet = PreNet(
            input_size=config['audio']['num_mels'],
            hidden_size=num_hidden * 2,
            output_size=num_hidden,
            dropout_rate=0.2)
        k = math.sqrt(1.0 / num_hidden)
        self.linear = dg.Linear(
            num_hidden,
            num_hidden,
            param_attr=fluid.ParamAttr(initializer=fluid.initializer.XavierInitializer()),
            bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(low=-k, high=k)))
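        # Sublayers kept in plain Python lists are not auto-registered by
        # dygraph, so each one is added with add_sublayer() to expose its
        # parameters to the optimizer and to state_dict().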
        self.selfattn_layers = [
            MultiheadAttention(num_hidden, num_hidden // num_head,
                               num_hidden // num_head) for _ in range(3)
        ]
        for i, layer in enumerate(self.selfattn_layers):
            self.add_sublayer("self_attn_{}".format(i), layer)
        self.attn_layers = [
            MultiheadAttention(num_hidden, num_hidden // num_head,
                               num_hidden // num_head) for _ in range(3)
        ]
        for i, layer in enumerate(self.attn_layers):
            self.add_sublayer("attn_{}".format(i), layer)
        self.ffns = [
            PositionwiseFeedForward(num_hidden, num_hidden * num_head, filter_size=1)
            for _ in range(3)
        ]
        for i, layer in enumerate(self.ffns):
            self.add_sublayer("ffns_{}".format(i), layer)
        self.mel_linear = dg.Linear(
            num_hidden,
            config['audio']['num_mels'] * config['audio']['outputs_per_step'],
            param_attr=fluid.ParamAttr(initializer=fluid.initializer.XavierInitializer()),
            bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(low=-k, high=k)))
        self.stop_linear = dg.Linear(
            num_hidden,
            1,
            param_attr=fluid.ParamAttr(initializer=fluid.initializer.XavierInitializer()),
            bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(low=-k, high=k)))
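        # Five-layer convolutional post-net; its output is added back to the
        # mel projection as a residual correction in forward().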
        self.postconvnet = PostConvNet(
            config['audio']['num_mels'],
            config['hidden_size'],
            filter_size=5,
            padding=4,
            num_conv=5,
            outputs_per_step=config['audio']['outputs_per_step'],
            use_cudnn=True)
    def forward(self, key, value, query, positional, mask, m_mask=None, m_self_mask=None, zero_mask=None):
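        """Run one decoder pass over a batch of mel frames.

        Shapes below are inferred from how the tensors are used:
            key, value: encoder outputs, [B, T_enc, num_hidden].
            query: previous mel frames fed to the pre-net, [B, T_dec, num_mels].
            positional: decoder position indices, [B, T_dec].
            mask: causal self-attention mask; m_mask, m_self_mask and
                zero_mask are padding masks (used in train mode only).

        Returns:
            (mel_out, post-net output, encoder-decoder attention weights,
            stop-token probabilities, self-attention weights).
        """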
        # In train mode, tile each mask over the attention heads so the
        # batch dimension becomes batch * num_head.
        if fluid.framework._dygraph_tracer()._train_mode:
            m_mask = layers.expand(m_mask, [self.num_head, 1, key.shape[1]])
            m_self_mask = layers.expand(m_self_mask, [self.num_head, 1, query.shape[1]])
            mask = layers.expand(mask, [self.num_head, 1, 1])
            zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1])
        else:
            m_mask, m_self_mask, zero_mask = None, None, None

        # Decoder pre-net: project the previous mel frames into the hidden space.
        query = self.decoder_prenet(query)

        # Centered position
        query = self.linear(query)

        # Add the scaled positional encoding.
        positional = self.pos_emb(positional)
        query = positional * self.alpha + query

        # Positional dropout
        query = fluid.layers.dropout(query, 0.1, dropout_implementation='upscale_in_train')

        # Attention: decoder self-attention followed by encoder-decoder attention.
        selfattn_list = list()
        attn_list = list()

        for selfattn, attn, ffn in zip(self.selfattn_layers, self.attn_layers, self.ffns):
            query, attn_dec = selfattn(query, query, query, mask=mask, query_mask=m_self_mask)
            query, attn_dot = attn(key, value, query, mask=zero_mask, query_mask=m_mask)
            query = ffn(query)
            selfattn_list.append(attn_dec)
            attn_list.append(attn_dot)

        # Mel linear projection
        mel_out = self.mel_linear(query)

        # Post-net: residual convolutional refinement of the mel output.
        out = self.postconvnet(mel_out)
        out = mel_out + out

        # Stop tokens
        stop_tokens = self.stop_linear(query)
        stop_tokens = layers.squeeze(stop_tokens, [-1])
        stop_tokens = layers.sigmoid(stop_tokens)

        return mel_out, out, attn_list, stop_tokens, selfattn_list
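
# A minimal construction sketch (an illustration, not part of the original
# module), assuming the parakeet package is importable and a config dict
# with the keys read above. Shapes follow forward()'s access patterns;
# eval mode is used so the None-mask branch is taken, since the helpers
# that build the triangular/padding masks live outside this file.
#
#   import numpy as np
#   config = {'hidden_size': 256,
#             'audio': {'num_mels': 80, 'outputs_per_step': 1}}
#   with dg.guard():
#       decoder = Decoder(num_hidden=256, config=config, num_head=4)
#       decoder.eval()
#       encoder_out = dg.to_variable(np.zeros((1, 50, 256), dtype='float32'))
#       mel_input = dg.to_variable(np.zeros((1, 20, 80), dtype='float32'))
#       positions = dg.to_variable(np.arange(1, 21, dtype='int64').reshape(1, 20))
#       mel_out, out, attns, stop_tokens, self_attns = decoder(
#           encoder_out, encoder_out, mel_input, positions, mask=None)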