fft_block.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
import paddle.fluid as fluid
from parakeet.modules.multihead_attention import MultiheadAttention
from parakeet.modules.ffn import PositionwiseFeedForward


class FFTBlock(dg.Layer):
    def __init__(self,
                 d_model,
                 d_inner,
                 n_head,
                 d_k,
                 d_q,
                 filter_size,
                 padding,
                 dropout=0.2):
        """Feed forward structure based on self-attention.

        Args:
            d_model (int): the dim of hidden layer in multihead attention.
            d_inner (int): the dim of hidden layer in ffn.
            n_head (int): the head number of multihead attention.
            d_k (int): the dim of key in multihead attention.
            d_q (int): the dim of query in multihead attention.
            filter_size (int): the conv kernel size.
            padding (int): the conv padding size.
            dropout (float, optional): dropout probability. Defaults to 0.2.
        """
        super(FFTBlock, self).__init__()
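        # Multi-head self-attention sub-layer; the position-wise FFN sub-layer follows below.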
        self.slf_attn = MultiheadAttention(
            d_model,
            d_k,
            d_q,
            num_head=n_head,
            is_bias=True,
            dropout=dropout,
            is_concat=False)
        self.pos_ffn = PositionwiseFeedForward(
            d_model,
            d_inner,
            filter_size=filter_size,
            padding=padding,
            dropout=dropout)

    def forward(self, enc_input, non_pad_mask, slf_attn_mask=None):
        """
63
        Feed forward block of FastSpeech
L
lifuchen 已提交
64 65
        
        Args:
66 67 68 69 70
            enc_input (Variable): shape(B, T, C), dtype float32, the embedding characters input, 
                where T means the timesteps of input.   
            non_pad_mask (Variable): shape(B, T, 1), dtype int64, the mask of sequence.
            slf_attn_mask (Variable, optional): shape(B, len_q, len_k), dtype int64, the mask of self attention,
                where len_q means the sequence length of query and len_k means the sequence length of key. Defaults to None. 
71
                     
L
lifuchen 已提交
72
        Returns:
73 74
            output (Variable): shape(B, T, C), the output after self-attention & ffn. 
            slf_attn (Variable): shape(B * n_head, T, T), the self attention.
L
lifuchen 已提交
75
        """
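        # Self-attention: query, key and value are all the encoder input.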
        output, slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)

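        # Zero out padded positions before and after the position-wise FFN.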
        output *= non_pad_mask

        output = self.pos_ffn(output)
        output *= non_pad_mask

        return output, slf_attn
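

# A minimal usage sketch (not part of the original file): build one FFTBlock and
# run a random batch through it in dygraph mode. The hyperparameters below are
# illustrative (roughly the FastSpeech paper settings), and a float32 mask is
# used here so the elementwise multiplication in forward() matches the dtype of
# the attention output; adjust both to fit the surrounding model.
if __name__ == "__main__":
    with dg.guard():
        batch_size, timesteps, d_model, n_head = 2, 16, 384, 2
        block = FFTBlock(
            d_model=d_model,
            d_inner=1536,
            n_head=n_head,
            d_k=d_model // n_head,
            d_q=d_model // n_head,
            filter_size=3,
            padding=1,
            dropout=0.2)
        enc_input = dg.to_variable(
            np.random.randn(batch_size, timesteps, d_model).astype("float32"))
        # Mark every position as non-padded for this toy batch.
        non_pad_mask = dg.to_variable(
            np.ones((batch_size, timesteps, 1), dtype="float32"))
        output, slf_attn = block(enc_input, non_pad_mask)
        print(output.shape)    # expected: [2, 16, 384], i.e. (B, T, C)
        print(slf_attn.shape)  # expected: [4, 16, 16], i.e. (B * n_head, T, T)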