# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
import paddle.fluid as fluid
from parakeet.modules.multihead_attention import MultiheadAttention
from parakeet.modules.ffn import PositionwiseFeedForward


class FFTBlock(dg.Layer):
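    """A single FFT (Feed-Forward Transformer) block as used in FastSpeech:
    multi-head self-attention followed by a position-wise feed-forward network.
    """
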
    def __init__(self,
                 d_model,
                 d_inner,
                 n_head,
                 d_k,
                 d_v,
                 filter_size,
                 padding,
                 dropout=0.2):
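        """
        Args:
            d_model (int): dimension of the block's input and output features.
            d_inner (int): hidden size of the position-wise feed-forward network.
            n_head (int): number of self-attention heads.
            d_k (int): dimension of keys (and queries) per head.
            d_v (int): dimension of values per head.
            filter_size (int): filter size used by the feed-forward network.
            padding (int): padding used by the feed-forward network.
            dropout (float): dropout probability. Default: 0.2.
        """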
        super(FFTBlock, self).__init__()
        # Multi-head self-attention sub-layer.
        self.slf_attn = MultiheadAttention(
            d_model,
            d_k,
            d_v,
            num_head=n_head,
            is_bias=True,
            dropout=dropout,
            is_concat=False)
        # Position-wise feed-forward sub-layer.
        self.pos_ffn = PositionwiseFeedForward(
            d_model,
            d_inner,
            filter_size=filter_size,
            padding=padding,
            dropout=dropout)

    def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
        """
        Feed Forward Transformer block in FastSpeech.
        
        Args:
            enc_input (Variable): Shape(B, T, C), dtype: float32. The embedding characters input. 
                T means the timesteps of input.
            non_pad_mask (Variable): Shape(B, T, 1), dtype: int64. The mask of sequence.
            slf_attn_mask (Variable): Shape(B, len_q, len_k), dtype: int64. The mask of self attention. 
                len_q means the sequence length of query, len_k means the sequence length of key.

        Returns:
            output (Variable), Shape(B, T, C), the output after self-attention & ffn.
            slf_attn (Variable), Shape(B * n_head, T, T), the self attention.
        """
        # Self-attention: query, key and value are all the encoder input.
        output, slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        # Zero out the padded positions; skip when no mask is given
        # (non_pad_mask defaults to None, so the multiply must be guarded).
        if non_pad_mask is not None:
            output *= non_pad_mask

        # Position-wise feed-forward network over the attended features.
        output = self.pos_ffn(output)
        if non_pad_mask is not None:
            output *= non_pad_mask

        return output, slf_attn
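

# A minimal usage sketch, not part of the original module. It assumes a
# Paddle 1.x fluid dygraph environment and that the Parakeet modules imported
# above are available; the hyper-parameters, shapes and mask dtype below are
# illustrative only.
if __name__ == "__main__":
    with dg.guard():
        block = FFTBlock(
            d_model=384,
            d_inner=1536,
            n_head=2,
            d_k=192,
            d_v=192,
            filter_size=3,
            padding=1,
            dropout=0.1)
        # Batch of 2 sequences, 10 timesteps, 384-dimensional features.
        enc_input = dg.to_variable(
            np.random.randn(2, 10, 384).astype("float32"))
        # All-ones mask, i.e. no padded positions in this toy batch.
        non_pad_mask = dg.to_variable(np.ones((2, 10, 1), dtype="float32"))
        out, attn = block(enc_input, non_pad_mask=non_pad_mask)
        # Expected: out is (B, T, C) = (2, 10, 384); attn is (B * n_head, T, T).
        print(out.shape, attn.shape)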