# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
import paddle.fluid as fluid
from parakeet.modules.multihead_attention import MultiheadAttention
from parakeet.modules.ffn import PositionwiseFeedForward


class FFTBlock(dg.Layer):
    def __init__(self,
                 d_model,
                 d_inner,
                 n_head,
                 d_k,
                 d_v,
                 filter_size,
                 padding,
                 dropout=0.2):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiheadAttention(
            d_model,
            d_k,
            d_v,
            num_head=n_head,
            is_bias=True,
            dropout=dropout,
            is_concat=False)
        self.pos_ffn = PositionwiseFeedForward(
            d_model,
            d_inner,
            filter_size=filter_size,
            padding=padding,
            dropout=dropout)

    def forward(self, enc_input, non_pad_mask, slf_attn_mask=None):
        """
        Feed Forward Transformer (FFT) block used in FastSpeech.

        Args:
            enc_input (Variable): the input character embeddings.
                Shape: (B, T, C), where T is the number of input time steps. dtype: float32.
            non_pad_mask (Variable): the sequence mask, 1 at real positions and 0 at
                padded positions. Shape: (B, T, 1), dtype: int64.
            slf_attn_mask (Variable, optional): the self-attention mask. Defaults to None.
                Shape: (B, len_q, len_k), where len_q is the query sequence length and
                len_k is the key sequence length. dtype: int64.

        Returns:
            output (Variable): the output after self-attention and the position-wise FFN.
                Shape: (B, T, C).
            slf_attn (Variable): the self-attention weights. Shape: (B * n_head, T, T).
        """
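        # Multi-head self-attention over the input sequence: enc_input serves as
        # query, key, and value, and slf_attn_mask (if given) masks out attention
        # scores for padded positions.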
        output, slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)

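        # Zero the outputs at padded time steps so padding does not leak into
        # the following layers.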
        output *= non_pad_mask

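        # Position-wise feed-forward sub-layer, followed by masking the padded
        # time steps again.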
        output = self.pos_ffn(output)
        output *= non_pad_mask

        return output, slf_attn
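

# A minimal usage sketch, not part of the original module: the hyperparameter
# values, batch size, sequence length, and the all-ones float32 non_pad_mask
# below are illustrative assumptions, and it presumes the parakeet modules
# imported above are available.
if __name__ == "__main__":
    with dg.guard():
        block = FFTBlock(
            d_model=384,
            d_inner=1536,
            n_head=2,
            d_k=64,
            d_v=64,
            filter_size=3,
            padding=1,
            dropout=0.2)
        # Two example sequences of 10 time steps with 384-dimensional embeddings.
        enc_input = dg.to_variable(
            np.random.randn(2, 10, 384).astype("float32"))
        # 1 for real positions, 0 for padded ones; kept float32 here so it can be
        # multiplied with the float32 activations.
        non_pad_mask = dg.to_variable(np.ones((2, 10, 1), dtype="float32"))
        output, slf_attn = block(enc_input, non_pad_mask)
        print(output.shape)    # expected per the docstring: [2, 10, 384]
        print(slf_attn.shape)  # expected per the docstring: [2 * 2, 10, 10]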