Commit 7a80f3e7 authored by Javier

implemented linear and 'standard' attention in a functional way so they are available via parameters passed to the main multi head attention class
Parent eb02f25f
......@@ -6,14 +6,14 @@ https://github.com/lucidrains
import math
import warnings
from enum import Enum
from typing import List, ContextManager
from typing import ContextManager
import torch
import einops
import torch.nn.functional as F
from torch import nn, einsum
from pytorch_widedeep.wdtypes import Tensor, Optional
from pytorch_widedeep.wdtypes import List, Tuple, Tensor, Optional
from pytorch_widedeep.models._get_activation_fn import get_activation_fn
......@@ -67,6 +67,61 @@ class AddNorm(nn.Module):
return self.ln(X + self.dropout(sublayer(X)))
def _standard_attention(
q: Tensor, k: Tensor, v: Tensor, head_dim: int, dropout: float
) -> Tuple[Tensor, Tensor]:
"""'Standard' multihead attention implemenation from [Attention Is All You
Need](https://arxiv.org/abs/1706.03762)
"""
# b: batch size
# s: seq length
# l: target sequence length
# m: used to refer to either s or l
# h: number of attention heads,
# d and e: head_dim
# Normalised Query, Key dot product. Fraction term in their Eq 1
scores = einsum("b h s d, b h l d -> b h s l", q, k) / math.sqrt(head_dim)
# Softmax
attn_weights = scores.softmax(dim=-1)
# Attention(Q, K, V ) (with dropout) in their Eq 1
attn_output = einsum(
"b h s l, b h l d -> b h s d", nn.Dropout(dropout)(attn_weights), v
)
return attn_weights, attn_output
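A minimal usage sketch of _standard_attention (illustrative only; the shapes are hypothetical and follow the [batch, heads, seq_len, head_dim] layout described in the comments above):

# Illustrative sketch: every tensor uses the [b, h, s, d] layout documented above
b, h, s, d = 2, 4, 10, 16
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
attn_weights, attn_output = _standard_attention(q, k, v, head_dim=d, dropout=0.1)
assert attn_weights.shape == (b, h, s, s)  # one weight per query/key pair
assert attn_output.shape == (b, h, s, d)   # same layout as the inputs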
def _linear_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
"""Linear attention implementation from [Transformers are RNNs: Fast
Autoregressive Transformers with Linear Attention]
(https://arxiv.org/abs/2006.16236)
"""
# b: batch size
# s: seq length
# l: target sequence length
# m: used to refer to either s or l
# h: number of attention heads,
# d and e: head_dim
q, k = (
nn.functional.elu(q) + 1,
nn.functional.elu(k) + 1,
)
# Term within the summation in the numerator of their Eq 5
scores = einsum("b h s e, b h l d -> b h e d", k, v)
# The denominator in their Eq 5
z = 1 / (torch.einsum("b h m d, b h d -> b h m", q, k.sum(dim=2)) + 1e-6)
# Their Eq 5
attn_output = torch.einsum("b h m e, b h e d, b h m -> b h m d", q, scores, z)
return attn_output
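Because linear attention never materialises the full seq-by-seq score matrix, only the attended output is returned and no per-pair weights are available. A minimal sketch with hypothetical shapes:

# Illustrative sketch: same [b, h, s, d] layout, but no attention matrix is produced
b, h, s, d = 2, 4, 10, 16
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
out = _linear_attention(q, k, v)
assert out.shape == (b, h, s, d)  # memory grows linearly with s, not quadratically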
class SDPBackend(Enum):
FLASH: int = 0
MEM_EFFICIENT: int = 1
......@@ -78,7 +133,7 @@ def _flash_kernel_setup(enabled_flash_backends: List[SDPBackend]) -> ContextMana
), "optimized kernels can only be used if CUDA is available."
warnings.warn(
"Note that FlashAttention This function is beta and subject to change",
"Note that FlashAttention is beta and subject to change",
RuntimeWarning,
)
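_flash_kernel_setup returns a ContextManager that restricts which scaled_dot_product_attention kernels may run. A sketch of what such a helper can look like, assuming PyTorch >= 2.0 and its torch.backends.cuda.sdp_kernel context manager (the library's actual implementation may differ):

def _flash_kernel_setup_sketch(
    enabled_flash_backends: List[SDPBackend],
) -> ContextManager:
    # Allow only the requested optimized kernels; disabling the math fallback
    # is an assumption of this sketch, not necessarily the library's choice.
    return torch.backends.cuda.sdp_kernel(
        enable_flash=SDPBackend.FLASH in enabled_flash_backends,
        enable_mem_efficient=SDPBackend.MEM_EFFICIENT in enabled_flash_backends,
        enable_math=False,
    )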
......@@ -103,7 +158,8 @@ class MultiHeadedAttention(nn.Module):
use_bias: bool,
dropout: float,
query_dim: Optional[int] = None,
use_flash: Optional[bool] = False,
use_linear_attention: bool = False,
use_flash_attention: Optional[bool] = False,
enabled_flash_backends: Optional[List[SDPBackend]] = [
SDPBackend.FLASH,
SDPBackend.MEM_EFFICIENT,
......@@ -113,14 +169,14 @@ class MultiHeadedAttention(nn.Module):
assert input_dim % n_heads == 0, "'input_dim' must be divisible by 'n_heads'"
self.use_flash = use_flash
self.use_linear_attention = use_linear_attention
self.use_flash_attention = use_flash_attention
self.enabled_flash_backends = enabled_flash_backends
self.head_dim = input_dim // n_heads
self.n_heads = n_heads
self.dropout_p = dropout
self.dropout = nn.Dropout(dropout)
self.dropout = dropout
query_dim = query_dim if query_dim is not None else input_dim
self.q_proj = nn.Linear(query_dim, input_dim, bias=use_bias)
......@@ -135,7 +191,7 @@ class MultiHeadedAttention(nn.Module):
# l: target sequence length
# m: used to refer to either s or l
# h: number of attention heads,
# d: head_dim
# d and e: head_dim
q = self.q_proj(X_Q)
X_KV = X_KV if X_KV is not None else X_Q
k, v = self.kv_proj(X_KV).chunk(2, dim=-1)
......@@ -144,103 +200,25 @@ class MultiHeadedAttention(nn.Module):
(q, k, v),
)
if self.use_flash:
if self.use_flash_attention:
with _flash_kernel_setup(self.enabled_flash_backends):
attn_output = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=None,
dropout_p=self.dropout_p if self.training else 0,
dropout_p=self.dropout if self.training else 0,
is_causal=False,
)
output = einops.rearrange(
attn_output, "b h s d -> b s (h d)", h=self.n_heads
)
self.attn_weights: Optional[Tensor] = None
elif self.use_linear_attention:
attn_output = _linear_attention(q, k, v)
self.attn_weights = None
else:
scores = einsum("b h s d, b h l d -> b h s l", q, k) / math.sqrt(
self.head_dim
)
attn_weights = scores.softmax(dim=-1)
self.attn_weights = attn_weights
attn_weights = self.dropout(attn_weights)
attn_output = einsum("b h s l, b h l d -> b h s d", attn_weights, v)
self.attn_weights, attn_output = _standard_attention(
q, k, v, self.head_dim, self.dropout
)
output = einops.rearrange(
attn_output, "b h s d -> b s (h d)", h=self.n_heads
)
if self.out_proj is not None:
output = self.out_proj(output)
return output
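With the new parameters, the same module covers all three attention paths. A minimal sketch, assuming the constructor shown above and CPU execution (flash attention additionally requires CUDA):

# Illustrative sketch; batch size, sequence length and input_dim are hypothetical.
X = torch.randn(32, 10, 64)  # [batch, seq_len, input_dim]

standard_attn = MultiHeadedAttention(64, 8, use_bias=False, dropout=0.1)
linear_attn = MultiHeadedAttention(
    64, 8, use_bias=False, dropout=0.1, use_linear_attention=True
)

out_standard = standard_attn(X)  # [32, 10, 64]; weights kept in standard_attn.attn_weights
out_linear = linear_attn(X)      # [32, 10, 64]; linear_attn.attn_weights is None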
class LinearAttention(nn.Module):
"""Linear Attention implementation from [Transformers are RNNs: Fast
Autoregressive Transformers with Linear Attention]
(https://arxiv.org/abs/2006.16236)
"""
def __init__(
self,
input_dim: int,
n_heads: int,
use_bias: bool,
dropout: float,
query_dim: Optional[int] = None,
):
super(LinearAttention, self).__init__()
assert input_dim % n_heads == 0, "'input_dim' must be divisible by 'n_heads'"
self.head_dim = input_dim // n_heads
self.n_heads = n_heads
self.dropout = nn.Dropout(dropout)
self.input_dim = input_dim
self.use_bias = use_bias
self.query_dim = query_dim
query_dim = query_dim if query_dim is not None else input_dim
self.q_proj = nn.Linear(query_dim, input_dim, bias=use_bias)
self.kv_proj = nn.Linear(input_dim, input_dim * 2, bias=use_bias)
self.out_proj = (
nn.Linear(input_dim, query_dim, bias=use_bias) if n_heads > 1 else None
)
def forward(self, X_Q: Tensor, X_KV: Optional[Tensor] = None) -> Tensor:
# b: batch size
# s: seq length
# l: target seq length
# m: used to refer to either s or l
# h: number of attention heads,
# d and e: head_dim
queries = self.q_proj(X_Q)
X_KV = X_KV if X_KV is not None else X_Q
keys, v = self.kv_proj(X_KV).chunk(2, dim=-1)
# Here we use the default feature map from the original implementation: elu() + 1
q, k = (
nn.functional.elu(queries) + 1,
nn.functional.elu(keys) + 1,
)
q, k, v = map(
lambda t: einops.rearrange(t, "b m (h d) -> b h m d", h=self.n_heads),
(q, k, v),
)
# The term within the summation in the numerator of their Eq 5
kv = einsum("b h s e, b h l d -> b h e d", k, v)
# The denominator in their Eq 5
z = 1 / (torch.einsum("b h m d, b h d -> b h m", q, k.sum(dim=2)) + 1e-6)
# Their Eq 5
attn_output = torch.einsum("b h m d, b h e d, b h m -> b h m d", q, kv, z)
output = einops.rearrange(attn_output, "b h s d -> b s (h d)", h=self.n_heads)
if self.out_proj is not None:
......@@ -307,8 +285,7 @@ class LinearAttentionLinformer(nn.Module):
scores = einsum("b h s d, b h k d -> b h s k", q, k) / math.sqrt(self.head_dim)
attn_weights = scores.softmax(dim=-1)
self.attn_weights = attn_weights
attn_weights = self.dropout(attn_weights)
output = einsum("b h s k, b h k d -> b h s d", attn_weights, v)
output = einsum("b h s k, b h k d -> b h s d", self.dropout(attn_weights), v)
output = einops.rearrange(output, "b h s d -> b s (h d)")
if self.out_proj is not None:
......
......@@ -22,6 +22,9 @@ class TransformerEncoder(nn.Module):
ff_dropout: float,
ff_factor: int,
activation: str,
use_linear_attention: bool,
use_flash_attention: bool,
# enabled_flash_backends,
):
super(TransformerEncoder, self).__init__()
......@@ -30,6 +33,9 @@ class TransformerEncoder(nn.Module):
n_heads,
use_bias,
attn_dropout,
None, # query_dim
use_linear_attention,
use_flash_attention,
)
self.ff = FeedForward(input_dim, ff_dropout, ff_factor, activation)
......
......@@ -92,6 +92,14 @@ class TabTransformer(BaseTabularModelWithAttention):
transformer_activation: str, default = "gelu"
Transformer Encoder activation function. _'tanh'_, _'relu'_,
_'leaky_relu'_, _'gelu'_, _'geglu'_ and _'reglu'_ are supported
use_linear_attention: Boolean, default = False,
Boolean indicating if Linear Attention from [Transformers are RNNs:
Fast Autoregressive Transformers with Linear Attention]
(https://arxiv.org/abs/2006.16236) will be used
use_flash_attention: Boolean, default = False,
Boolean indicating if [Flash Attention]
(https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
will be used
mlp_hidden_dims: List, Optional, default = None
MLP hidden dimensions. If not provided it will default to $[l,
4\times l, 2 \times l]$ where $l$ is the MLP's input dimension
......@@ -158,7 +166,10 @@ class TabTransformer(BaseTabularModelWithAttention):
ff_dropout: float = 0.1,
ff_factor: int = 4,
transformer_activation: str = "gelu",
use_linear_attention: bool = False,
use_flash_attention: bool = False,
mlp_hidden_dims: Optional[List[int]] = None,
# enabled_flash_backends,
mlp_activation: str = "relu",
mlp_dropout: float = 0.1,
mlp_batchnorm: bool = False,
......@@ -190,6 +201,8 @@ class TabTransformer(BaseTabularModelWithAttention):
self.attn_dropout = attn_dropout
self.ff_dropout = ff_dropout
self.transformer_activation = transformer_activation
self.use_linear_attention = use_linear_attention
self.use_flash_attention = use_flash_attention
self.ff_factor = ff_factor
self.mlp_hidden_dims = mlp_hidden_dims
......@@ -222,6 +235,8 @@ class TabTransformer(BaseTabularModelWithAttention):
ff_dropout,
ff_factor,
transformer_activation,
use_linear_attention,
use_flash_attention,
),
)
......@@ -297,6 +312,12 @@ class TabTransformer(BaseTabularModelWithAttention):
batch size, $H$ is the number of attention heads and $F$ is the
number of features/columns in the dataset
"""
if self.use_flash_attention or self.use_linear_attention:
raise ValueError(
"Extraction of the attention weights is not supported for "
"linear or flash attention"
)
return [blk.attn.attn_weights for blk in self.encoder]
def _compute_attn_output_dim(self) -> int:
......
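End to end, the new flags are simply forwarded from TabTransformer to its encoder blocks. A hedged usage sketch; the column setup, cardinalities and data below are made up for illustration and assume the usual pytorch_widedeep TabTransformer constructor arguments (column_idx, cat_embed_input, continuous_cols):

import torch
from pytorch_widedeep.models import TabTransformer

# Hypothetical tabular setup: two categorical columns and one continuous column
column_idx = {"a": 0, "b": 1, "c": 2}
cat_embed_input = [("a", 5), ("b", 3)]  # (column name, number of categories)

model = TabTransformer(
    column_idx=column_idx,
    cat_embed_input=cat_embed_input,
    continuous_cols=["c"],
    use_linear_attention=True,  # or use_flash_attention=True on a CUDA device
)

X = torch.cat([torch.randint(0, 3, (8, 2)).float(), torch.rand(8, 1)], dim=1)
out = model(X)  # shape: [8, model.output_dim]

# Attention weights are only tracked for the standard path, so accessing
# model.attention_weights raises a ValueError when either new flag is set.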
......@@ -85,6 +85,8 @@ class Transformer(nn.Module):
ff_dropout: float = 0.1,
ff_factor: int = 4,
activation: str = "gelu",
use_linear_attention: bool = False,
use_flash_attention: bool = False,
with_cls_token: bool = False,
*, # from here on pos encoding args
with_pos_encoding: bool = True,
......@@ -101,6 +103,8 @@ class Transformer(nn.Module):
self.ff_dropout = ff_dropout
self.ff_factor = ff_factor
self.activation = activation
self.use_linear_attention = use_linear_attention
self.use_flash_attention = use_flash_attention
self.with_cls_token = with_cls_token
self.with_pos_encoding = with_pos_encoding
self.pos_encoding_dropout = pos_encoding_dropout
......@@ -131,6 +135,8 @@ class Transformer(nn.Module):
ff_dropout,
ff_factor,
activation,
use_linear_attention,
use_flash_attention,
),
)
......