import paddle.fluid.dygraph as dg
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from parakeet.models.transformer_tts.utils import *
from parakeet.modules.multihead_attention import MultiheadAttention
from parakeet.modules.ffn import PositionwiseFeedForward
from parakeet.models.transformer_tts.encoderprenet import EncoderPrenet

class Encoder(dg.Layer):
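    """Transformer-TTS character encoder: a prenet, a sinusoidal positional
    encoding with a learnable scale, and stacked self-attention/FFN blocks."""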
    def __init__(self, embedding_size, num_hidden, num_head=4):
        super(Encoder, self).__init__()
        self.num_hidden = num_hidden
        self.num_head = num_head
        # Learnable scale applied to the positional encoding, initialized to 1.
        param = fluid.ParamAttr(initializer=fluid.initializer.Constant(value=1.0))
        self.alpha = self.create_parameter(shape=(1, ), attr=param, dtype='float32')
        # Fixed (non-trainable) sinusoidal position-embedding table.
        self.pos_inp = get_sinusoid_encoding_table(1024, self.num_hidden, padding_idx=0)
        self.pos_emb = dg.Embedding(size=[1024, num_hidden],
                                    param_attr=fluid.ParamAttr(
                                        initializer=fluid.initializer.NumpyArrayInitializer(self.pos_inp),
                                        trainable=False))
        self.encoder_prenet = EncoderPrenet(embedding_size=embedding_size,
                                            num_hidden=num_hidden,
                                            use_cudnn=True)
        # Three self-attention layers, each paired with a position-wise FFN.
        self.layers = [MultiheadAttention(num_hidden, num_hidden // num_head, num_hidden // num_head) for _ in range(3)]
        for i, layer in enumerate(self.layers):
            self.add_sublayer("self_attn_{}".format(i), layer)
        self.ffns = [PositionwiseFeedForward(num_hidden, num_hidden * num_head, filter_size=1, use_cudnn=True) for _ in range(3)]
        for i, layer in enumerate(self.ffns):
            self.add_sublayer("ffns_{}".format(i), layer)

    def forward(self, x, positional, mask=None, query_mask=None):
        """Encode character input.

        Returns the encoded sequence of shape (N, T, C) and a list of
        per-layer attention weights.
        """
        if fluid.framework._dygraph_tracer()._train_mode:
            # Tile the padding masks across attention heads.
            seq_len_key = x.shape[1]
            query_mask = layers.expand(query_mask, [self.num_head, 1, seq_len_key])
            mask = layers.expand(mask, [self.num_head, 1, 1])
        else:
            query_mask, mask = None, None
        
        # Encoder prenet: embed characters and project to hidden size.
        x = self.encoder_prenet(x)  # (N, T, C)

        # Add the positional encoding, scaled by the learnable alpha.
        positional = self.pos_emb(positional)
        x = positional * self.alpha + x  # (N, T, C)
        
        # Positional dropout; 'upscale_in_train' rescales kept activations by
        # 1 / (1 - p) during training so inference needs no extra scaling.
        x = layers.dropout(x, 0.1, dropout_implementation='upscale_in_train')
        
        # Self-attention encoder: alternate attention and FFN blocks,
        # collecting per-layer attention weights.
        attentions = list()
        for layer, ffn in zip(self.layers, self.ffns):
            x, attention = layer(x, x, x, mask=mask, query_mask=query_mask)
            x = ffn(x)
            attentions.append(attention)

        return x, attentions
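

# Usage sketch (not part of the original module): run a dummy batch through
# the encoder in dygraph mode. The vocabulary/shape choices and the all-ones
# masks below are illustrative assumptions; real masks are derived from
# padding in the data pipeline.
if __name__ == "__main__":
    import numpy as np

    with dg.guard():
        encoder = Encoder(embedding_size=512, num_hidden=256, num_head=4)
        text = dg.to_variable(np.random.randint(1, 100, (2, 50)).astype("int64"))
        pos = dg.to_variable(np.tile(np.arange(1, 51), (2, 1)).astype("int64"))
        # (N, T, 1) query mask and (N, T, T) attention mask; forward tiles
        # them across the heads when the tracer is in train mode. All-ones
        # means "no padding" here.
        query_mask = dg.to_variable(np.ones((2, 50, 1), dtype="float32"))
        mask = dg.to_variable(np.ones((2, 50, 50), dtype="float32"))
        out, attentions = encoder(text, pos, mask=mask, query_mask=query_mask)
        print(out.shape)  # expect [2, 50, 256]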