import math

import torch
from torch import nn

from pytorch_widedeep.wdtypes import Union, Tensor, Optional
from pytorch_widedeep.utils.general_utils import Alias
from pytorch_widedeep.models.tabular.transformers._encoders import (
    TransformerEncoder,
)


class Transformer(nn.Module):
    r"""Basic Encoder-Only Transformer Model for text classification/regression.
    As with all other models in the library, this model can be used as the
    `deeptext` component of a Wide & Deep model or independently by itself.

    :information_source: **NOTE**:
    This model is introduced in the context of recommendation systems and is
    intended for sequences of any nature (e.g. items). It can, of course,
    still be used for text. However, at this stage, we have decided not to
    include the option of loading pretrained word vectors, since we aim to
    integrate the library with Hugging Face in the (hopefully) near future.

    Parameters
    ----------
    vocab_size: int
        Number of words in the vocabulary
    input_dim: int
        Dimension of the token embeddings

        Param aliases: `embed_dim`, `d_model`. <br/>

    seq_length: int
        Input sequence length
    n_heads: int, default = 8
        Number of attention heads per Transformer block
    n_blocks: int, default = 4
        Number of Transformer blocks
    attn_dropout: float, default = 0.1
        Dropout that will be applied to the Multi-Head Attention layers
    ff_dropout: float, default = 0.1
        Dropout that will be applied to the FeedForward network
    ff_factor: int, default = 4
        Multiplicative factor applied to the first layer of the FF network in
        each Transformer block. This is normally set to 4.
    activation: str, default = "gelu"
        Transformer Encoder activation function. _'tanh'_, _'relu'_,
        _'leaky_relu'_, _'gelu'_, _'geglu'_ and _'reglu'_ are supported
    use_linear_attention: bool, default = False
        Boolean indicating if Linear Attention will be used in the Transformer
        blocks
    use_flash_attention: bool, default = False
        Boolean indicating if Flash Attention will be used in the Transformer
        blocks
    padding_idx: int, default = 0
        Index of the padding token in the padded-tokenised sequences.
    with_cls_token: bool, default = False
        Boolean indicating if a `'[CLS]'` token is included in the tokenized
        sequences. If present, the final hidden state corresponding to this
        token is used as the aggregated representation for classification and
        regression tasks (see the second example in the Examples section
        below). **NOTE**: if included in the tokenized sequences it must be
        inserted as the first token in the sequences.
    with_pos_encoding: bool, default = True
        Boolean indicating if positional encoding will be used
    pos_encoding_dropout: float, default = 0.1
        Positional encoding dropout
    pos_encoder: nn.Module, Optional, default = None
        By default this model uses a standard sinusoidal positional encoding
        approach. However, any custom positional encoder can also be used and
        passed to the Transformer model via the 'pos_encoder' parameter

    Attributes
    ----------
    embedding: nn.Module
        Standard token embedding layer
    pos_encoder: nn.Module
        Positional Encoder
    encoder: nn.Module
        Sequence of Transformer blocks

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import Transformer
    >>> X_text = torch.cat((torch.zeros([5,1]), torch.empty(5, 4).random_(1,4)), axis=1)
    >>> model = Transformer(vocab_size=4, seq_length=5, input_dim=8, n_heads=1, n_blocks=1)
    >>> out = model(X_text)
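
    The second, illustrative setup below prepends a `'[CLS]'` token to each
    sequence (token id 3 is reused here purely for illustration as the
    `'[CLS]'` token):

    >>> X_text_cls = torch.cat((torch.full((5, 1), 3.0), X_text), axis=1)
    >>> model_cls = Transformer(vocab_size=4, seq_length=6, input_dim=8, n_heads=1,
    ...     n_blocks=1, with_cls_token=True)
    >>> out_cls = model_cls(X_text_cls)  # shape (5, 8): only the '[CLS]' hidden state is returned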
    """

    @Alias("input_dim", ["embed_dim", "d_model"])
    @Alias("seq_length", ["max_length", "maxlen"])
    def __init__(
        self,
        vocab_size: int,
        seq_length: int,
        input_dim: int,
        n_heads: int,
        n_blocks: int,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.1,
        ff_factor: int = 4,
        activation: str = "gelu",
        use_linear_attention: bool = False,
        use_flash_attention: bool = False,
        padding_idx: int = 0,
        with_cls_token: bool = False,
        *,  # from here on pos encoding args
        with_pos_encoding: bool = True,
        pos_encoding_dropout: float = 0.1,
        pos_encoder: Optional[nn.Module] = None,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.seq_length = seq_length
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.ff_factor = ff_factor
        self.activation = activation
        self.use_linear_attention = use_linear_attention
        self.use_flash_attention = use_flash_attention
        self.padding_idx = padding_idx
        self.with_cls_token = with_cls_token
        self.with_pos_encoding = with_pos_encoding
        self.pos_encoding_dropout = pos_encoding_dropout

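        # token embedding lookup; the embedding at `padding_idx` is zero-initialised
        # and is not updated during training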
        self.embedding = nn.Embedding(
            vocab_size, input_dim, padding_idx=self.padding_idx
        )

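        # positional encoder: a user supplied module if passed, otherwise the default
        # sinusoidal PositionalEncoding, or nn.Identity when with_pos_encoding=False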
        if with_pos_encoding:
            if pos_encoder is not None:
                self.pos_encoder: Union[
                    nn.Module, nn.Identity, PositionalEncoding
                ] = pos_encoder
            else:
                self.pos_encoder = PositionalEncoding(
                    input_dim, pos_encoding_dropout, seq_length
                )
        else:
            self.pos_encoder = nn.Identity()

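        # stack of n_blocks TransformerEncoder blocks applied sequentially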
        self.encoder = nn.Sequential()
        for i in range(n_blocks):
            self.encoder.add_module(
                "transformer_block" + str(i),
                TransformerEncoder(
                    input_dim,
                    n_heads,
                    False,  # use_qkv_bias
                    attn_dropout,
                    ff_dropout,
                    ff_factor,
                    activation,
                    use_linear_attention,
                    use_flash_attention,
                ),
            )

    def forward(self, X: Tensor) -> Tensor:
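        # token indices -> embeddings (indices are cast to long in case a float tensor is passed)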
        x = self.embedding(X.long())
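        # add positional information (nn.Identity when positional encoding is not used)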
        x = self.pos_encoder(x)
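        # pass through the stack of Transformer blocks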
        x = self.encoder(x)
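        # aggregate: keep only the '[CLS]' hidden state if present, otherwise flatten all positions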
        if self.with_cls_token:
            x = x[:, 0, :]
        else:
            x = x.flatten(1)
        return x

    @property
    def output_dim(self) -> int:
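        # the '[CLS]' aggregation returns input_dim per sample, while flattening
        # returns input_dim * seq_length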
        if self.with_cls_token:
            output_dim = self.input_dim
        else:
            output_dim = self.input_dim * self.seq_length
        return output_dim


class PositionalEncoding(nn.Module):
    """Positional Encoding copied and pasted directly from [The Beginners'
    Tutorial]
    (https://pytorch.org/tutorials/beginner/transformer_tutorial.html) at the
    Pytorch site. Here is simply adapated so that the input sequence length
    must be specified and in our implementation the input tensor dimensions
    are arranged as `[batch_size, seq_len, embedding_dim]` instead of `
    [seq_len, batch_size, embedding_dim]` , as in the before mentioned
    tutorial

    Parameters
    ----------
    input_dim: int
        Dimension of the token embeddings
    dropout: float
        Positional encoding dropout
    seq_length: int
        Input sequence length

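    Examples
    --------
    A minimal usage sketch (it assumes `PositionalEncoding` is imported from
    this module):

    >>> import torch
    >>> pos_encoder = PositionalEncoding(input_dim=8, dropout=0.1, seq_length=5)
    >>> X = torch.randn(2, 5, 8)  # [batch_size, seq_len, embedding_dim]
    >>> out = pos_encoder(X)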
    """

    def __init__(self, input_dim: int, dropout: float, seq_length: int):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # standard sinusoidal encoding: PE(pos, 2i) = sin(pos / 10000^(2i / input_dim))
        # and PE(pos, 2i + 1) = cos(pos / 10000^(2i / input_dim))
        position = torch.arange(seq_length).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, input_dim, 2) * (-math.log(10000.0) / input_dim)
        )
        # shape (1, seq_length, input_dim), so it broadcasts over the batch dimension
        pe = torch.zeros(1, seq_length, input_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        # buffer: saved with the module and moved across devices, but not a trainable parameter
        self.register_buffer("pe", pe)

    def forward(self, X: Tensor) -> Tensor:
        return self.dropout(X + self.pe)