decoder.py 11.2 KB
Newer Older
H
Hui Zhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
小湉湉's avatar
小湉湉 已提交
14
# Modified from espnet(https://github.com/espnet/espnet)
H
Hui Zhang 已提交
15 16 17 18 19 20 21 22 23 24 25
# NOTE: dynamic convolution support has been temporarily removed.
"""Decoder definition."""
import logging
from typing import Any
from typing import List
from typing import Tuple

import paddle
import paddle.nn.functional as F
from paddle import nn

26 27 28 29 30 31 32 33
from paddlespeech.t2s.modules.fastspeech2_transformer.attention import MultiHeadedAttention
from paddlespeech.t2s.modules.fastspeech2_transformer.decoder_layer import DecoderLayer
from paddlespeech.t2s.modules.fastspeech2_transformer.embedding import PositionalEncoding
from paddlespeech.t2s.modules.fastspeech2_transformer.lightconv import LightweightConvolution
from paddlespeech.t2s.modules.fastspeech2_transformer.mask import subsequent_mask
from paddlespeech.t2s.modules.fastspeech2_transformer.positionwise_feed_forward import PositionwiseFeedForward
from paddlespeech.t2s.modules.fastspeech2_transformer.repeat import repeat
from paddlespeech.t2s.modules.layer_norm import LayerNorm
H
Hui Zhang 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293


class Decoder(nn.Layer):
    """Transformer decoder module.

    Parameters
    ----------
    odim : int
        Output dimension. Also used as the number of embeddings when
        input_layer == "embed", and as the input feature size when
        input_layer == "linear".
    selfattention_layer_type : str
        Self-attention layer type, either "selfattn" or "lightconv".
    attention_dim : int
        Dimension of attention.
    attention_heads : int
        The number of heads of multi head attention.
    conv_wshare : int
        The number of kernels of convolution. Only used when
        selfattention_layer_type == "lightconv".
    conv_kernel_length : Union[int, str]
        Kernel size of convolution. Either a single int shared by every
        block, or a "_"-separated string giving one size per block
        (e.g. "71_71_71_71_71_71"). Only used when
        selfattention_layer_type == "lightconv".
    conv_usebias : bool
        Whether to use bias in convolution. Only used when
        selfattention_layer_type == "lightconv".
    linear_units : int
        The number of units of position-wise feed forward.
    num_blocks : int
        The number of decoder blocks.
    dropout_rate : float
        Dropout rate.
    positional_dropout_rate : float
        Dropout rate after adding positional encoding.
    self_attention_dropout_rate : float
        Dropout rate in self-attention.
    src_attention_dropout_rate : float
        Dropout rate in source-attention.
    input_layer : Union[str, paddle.nn.Layer]
        Input layer type: "embed", "linear", or a custom paddle.nn.Layer
        (which is followed by the positional encoding).
    use_output_layer : bool
        Whether to use output layer.
    pos_enc_class : paddle.nn.Layer
        Positional encoding module class.
        `PositionalEncoding` or `ScaledPositionalEncoding`
    normalize_before : bool
        Passed to each DecoderLayer; when True, a final LayerNorm is also
        applied to the output of the last block.
    concat_after : bool
        Whether to concat attention layer's input and output.
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)

    Raises
    ------
    NotImplementedError
        If input_layer or selfattention_layer_type is not supported.
    ValueError
        If conv_kernel_length is a string listing fewer than num_blocks
        kernel sizes (only checked for "lightconv").

    """

    def __init__(
            self,
            odim,
            selfattention_layer_type="selfattn",
            attention_dim=256,
            attention_heads=4,
            conv_wshare=4,
            conv_kernel_length=11,
            conv_usebias=False,
            linear_units=2048,
            num_blocks=6,
            dropout_rate=0.1,
            positional_dropout_rate=0.1,
            self_attention_dropout_rate=0.0,
            src_attention_dropout_rate=0.0,
            input_layer="embed",
            use_output_layer=True,
            pos_enc_class=PositionalEncoding,
            normalize_before=True,
            concat_after=False, ):
        """Construct a Decoder object."""
        nn.Layer.__init__(self)
        if input_layer == "embed":
            self.embed = nn.Sequential(
                nn.Embedding(odim, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "linear":
            self.embed = nn.Sequential(
                nn.Linear(odim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.Dropout(dropout_rate),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif isinstance(input_layer, nn.Layer):
            self.embed = nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate))
        else:
            raise NotImplementedError(
                "only `embed`, `linear` or paddle.nn.Layer is supported.")
        self.normalize_before = normalize_before

        # self-attention module definition: one constructor plus one
        # argument tuple per block, consumed by the `repeat` call below.
        if selfattention_layer_type == "selfattn":
            logging.info("decoder self-attention layer type = self-attention")
            decoder_selfattn_layer = MultiHeadedAttention
            decoder_selfattn_layer_args = [
                (attention_heads, attention_dim, self_attention_dropout_rate, )
            ] * num_blocks
        elif selfattention_layer_type == "lightconv":
            logging.info(
                "decoder self-attention layer type = lightweight convolution")
            kernel_sizes = self._parse_kernel_sizes(conv_kernel_length,
                                                    num_blocks)
            decoder_selfattn_layer = LightweightConvolution
            # NOTE(review): the positional `True` is presumably
            # LightweightConvolution's kernel-mask flag — confirm there.
            decoder_selfattn_layer_args = [(
                conv_wshare, attention_dim, self_attention_dropout_rate,
                kernel_sizes[lnum], True, conv_usebias, )
                                           for lnum in range(num_blocks)]
        else:
            # Fail early instead of hitting an unbound name inside `repeat`.
            raise NotImplementedError(
                f"unsupported selfattention_layer_type: {selfattention_layer_type}"
            )

        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                decoder_selfattn_layer(*decoder_selfattn_layer_args[lnum]),
                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after, ), )
        self.selfattention_layer_type = selfattention_layer_type
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = nn.Linear(attention_dim, odim)
        else:
            self.output_layer = None

    @staticmethod
    def _parse_kernel_sizes(conv_kernel_length, num_blocks):
        """Return one convolution kernel size per decoder block.

        Parameters
        ----------
        conv_kernel_length : Union[int, str]
            A single int (reused for every block) or a "_"-separated string
            of ints. Entries beyond ``num_blocks`` are ignored.
        num_blocks : int
            The number of decoder blocks.

        Returns
        ----------
        List[int]
            ``num_blocks`` kernel sizes.

        Raises
        ------
        ValueError
            If the string lists fewer than ``num_blocks`` sizes.
        """
        if isinstance(conv_kernel_length, int):
            return [conv_kernel_length] * num_blocks
        pieces = str(conv_kernel_length).split("_")
        if len(pieces) < num_blocks:
            raise ValueError(
                f"conv_kernel_length '{conv_kernel_length}' specifies "
                f"{len(pieces)} kernel sizes, but num_blocks={num_blocks}.")
        return [int(piece) for piece in pieces[:num_blocks]]

    def forward(self, tgt, tgt_mask, memory, memory_mask):
        """Forward decoder.

        Parameters
        ----------
        tgt : paddle.Tensor
            Input token ids, int64 (#batch, maxlen_out) if input_layer == "embed".
            In the other case, input tensor (#batch, maxlen_out, odim).
        tgt_mask : paddle.Tensor
            Input token mask (#batch, maxlen_out).
        memory : paddle.Tensor
            Encoded memory, float32 (#batch, maxlen_in, feat).
        memory_mask : paddle.Tensor
            Encoded memory mask (#batch, maxlen_in).

        Returns
        ----------
        paddle.Tensor
            Decoded token score before softmax (#batch, maxlen_out, odim)
            if use_output_layer is True. In the other case, final block outputs
            (#batch, maxlen_out, attention_dim).
        paddle.Tensor
            Target mask as returned by the decoder layers.

        """
        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(x, tgt_mask, memory,
                                                         memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        return x, tgt_mask

    def forward_one_step(self, tgt, tgt_mask, memory, cache=None):
        """Forward one step.

        Parameters
        ----------
        tgt : paddle.Tensor
            Input token ids, int64 (#batch, maxlen_out).
        tgt_mask : paddle.Tensor
            Input token mask (#batch, maxlen_out).
        memory : paddle.Tensor
            Encoded memory, float32 (#batch, maxlen_in, feat).
        cache : List[paddle.Tensor], optional
            List of cached tensors, one per decoder layer.
            Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns
        ----------
        paddle.Tensor
            Output for the last position only: log-probabilities
            (#batch, odim) when the output layer is used, otherwise the
            final block output (#batch, attention_dim).
        List[paddle.Tensor]
            Output of each decoder layer, to be passed back as ``cache``.

        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            # No memory mask is passed during step-wise decoding.
            x, tgt_mask, memory, _ = decoder(
                x, tgt_mask, memory, None, cache=c)
            new_cache.append(x)

        # Only the last position is needed to predict the next token.
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = F.log_softmax(self.output_layer(y), axis=-1)

        return y, new_cache

    # beam search API (see ScorerInterface)
    def score(self, ys, state, x):
        """Score the next token for a single (unbatched) hypothesis.

        Parameters
        ----------
        ys : paddle.Tensor
            Prefix token ids (ylen,).
        state : List[paddle.Tensor] or None
            Cache returned by a previous call, or None.
        x : paddle.Tensor
            Encoder output for this hypothesis, without a batch axis
            (presumably (xlen, feat) — confirm against callers).

        Returns
        ----------
        paddle.Tensor
            Next-token scores with the batch axis removed.
        List[paddle.Tensor]
            Updated cache for the next call.

        """
        ys_mask = subsequent_mask(len(ys)).unsqueeze(0)
        if self.selfattention_layer_type != "selfattn":
            # TODO(karita): implement cache
            logging.warning(
                f"{self.selfattention_layer_type} does not support cached decoding."
            )
            state = None
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
        return logp.squeeze(0), state

    # batch beam search API (see BatchScorerInterface)
    def batch_score(self,
                    ys: paddle.Tensor,
                    states: List[Any],
                    xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
        """Score new token batch (required).

        Parameters
        ----------
        ys : paddle.Tensor
            paddle.int64 prefix tokens (n_batch, ylen).
        states : List[Any]
            Scorer states for prefix tokens, one per batch entry; each is
            either None or a per-layer list of cached tensors.
        xs : paddle.Tensor
            The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns
        ----------
        tuple[paddle.Tensor, List[Any]]
            Tuple of batchified scores for next token with shape of
            `(n_batch, n_vocab)` and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                paddle.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        ys_mask = subsequent_mask(ys.shape[-1]).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)]
                      for b in range(n_batch)]
        return logp, state_list